341 lines
13 KiB
Python
341 lines
13 KiB
Python
import json
|
|
import logging
|
|
import base64
|
|
from datetime import datetime
|
|
from typing import Dict, Any, Optional
|
|
import functions_framework
|
|
import vertexai
|
|
from vertexai.vision_models import MultiModalEmbeddingModel, Image as VertexImage
|
|
from google.cloud import firestore
|
|
from google.cloud import storage
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models
|
|
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
|
import numpy as np
|
|
from PIL import Image
|
|
import io
|
|
import os
|
|
import uuid
|
|
|
|
# Module-wide logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Vertex AI settings come from the environment. The project ID is taken from
# GOOGLE_CLOUD_PROJECT and falls back to GCP_PROJECT.
LOCATION = os.getenv('VERTEX_AI_LOCATION', 'us-central1')
PROJECT_ID = os.getenv('GOOGLE_CLOUD_PROJECT') or os.getenv('GCP_PROJECT')

if not PROJECT_ID:
    # Vertex AI is left uninitialized in this case; only an error is logged.
    logger.error("PROJECT_ID not found in environment variables")
else:
    vertexai.init(project=PROJECT_ID, location=LOCATION)
    logger.info(f"Initialized Vertex AI with project {PROJECT_ID} in location {LOCATION}")
|
|
|
|
# Firestore target (project and database) is read from the environment
FIRESTORE_PROJECT_ID = os.getenv('FIRESTORE_PROJECT_ID')
FIRESTORE_DATABASE_NAME = os.getenv('FIRESTORE_DATABASE_NAME', '(default)')

# An explicit project is passed only when one is configured; otherwise the
# client is constructed with just the database name.
firestore_client = (
    firestore.Client(project=FIRESTORE_PROJECT_ID, database=FIRESTORE_DATABASE_NAME)
    if FIRESTORE_PROJECT_ID
    else firestore.Client(database=FIRESTORE_DATABASE_NAME)
)

storage_client = storage.Client()

# Name of the bucket that image storage paths are resolved against
GCS_BUCKET_NAME = os.getenv('GCS_BUCKET_NAME', 'contoso-images')
|
|
|
|
# Qdrant connection settings, each overridable through the environment
QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.getenv('QDRANT_PORT', '6333'))
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
QDRANT_COLLECTION = os.getenv('QDRANT_COLLECTION', 'image_vectors')
QDRANT_HTTPS = os.getenv('QDRANT_HTTPS', 'false').lower() == 'true'

# Build the client first; only when that succeeds is the collection checked.
# A failure in either step leaves qdrant_client set to None.
try:
    qdrant_client = QdrantClient(
        host=QDRANT_HOST,
        port=QDRANT_PORT,
        api_key=QDRANT_API_KEY,
        https=QDRANT_HTTPS,
    )
except Exception as e:
    logger.error(f"Failed to initialize Qdrant client: {e}")
    qdrant_client = None
else:
    try:
        existing_collections = {
            col.name for col in qdrant_client.get_collections().collections
        }

        if QDRANT_COLLECTION in existing_collections:
            logger.info(f"Collection {QDRANT_COLLECTION} already exists")
        else:
            logger.info(f"Creating Qdrant collection: {QDRANT_COLLECTION}")
            qdrant_client.create_collection(
                collection_name=QDRANT_COLLECTION,
                vectors_config=VectorParams(
                    size=1408,  # Vertex AI multimodal embedding size
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Collection {QDRANT_COLLECTION} created successfully")
    except Exception as e:
        logger.error(f"Error ensuring Qdrant collection exists: {e}")
        qdrant_client = None

if not qdrant_client:
    logger.warning("Qdrant not configured, embeddings will not be stored")
|
|
|
|
@functions_framework.cloud_event
def process_image_embedding(cloud_event):
    """
    Cloud Function triggered by Pub/Sub to process image embeddings

    The event is expected to carry, under ``data["message"]["data"]``, a
    base64-encoded UTF-8 JSON object with ``image_id``, ``storage_path`` and
    ``team_id``, plus an optional ``retry_count`` (defaults to 0).

    Args:
        cloud_event: The CloudEvent delivered by the functions framework.

    Raises:
        Exception: Any error raised while decoding or processing the message
            is re-raised after the image (when its ID is known) has been
            marked as failed.
    """
    # Filled in as soon as the payload is parsed so the error handler can
    # report against the right image without decoding the message again.
    image_id = None
    retry_count = 0

    try:
        # Decode the Pub/Sub message
        message_data = base64.b64decode(cloud_event.data["message"]["data"]).decode('utf-8')
        message = json.loads(message_data)

        logger.info(f"Processing image embedding task: {message}")

        # Extract task data
        image_id = message.get('image_id')
        retry_count = message.get('retry_count', 0)
        storage_path = message.get('storage_path')
        team_id = message.get('team_id')

        if not all([image_id, storage_path, team_id]):
            # Incomplete messages are logged and dropped without raising.
            logger.error(f"Missing required fields in message: {message}")
            return

        # Update image status to processing
        update_image_status(image_id, 'processing', retry_count)

        # Process the image
        success = process_image(image_id, storage_path, team_id, retry_count)

        if success:
            logger.info(f"Successfully processed image {image_id}")
            update_image_status(image_id, 'success', retry_count)
        else:
            logger.error(f"Failed to process image {image_id}")
            update_image_status(image_id, 'failed', retry_count, "Processing failed")

        # NOTE(review): the original comment here said retries are handled by
        # the Pub/Sub retry policy (up to 3 times). The failure branch above
        # returns normally rather than raising, so whether a redelivery is
        # actually triggered depends on the trigger configuration - confirm.

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error in process_image_embedding: {e}")

        if image_id:
            try:
                update_image_status(image_id, 'failed', retry_count, str(e))
            except Exception:
                # Best effort only: never let the status write mask the
                # original error, but do not hide its failure either.
                logger.exception(f"Could not mark image {image_id} as failed")

        # Bare raise preserves the original traceback for the framework.
        raise
|
|
|
|
def process_image(image_id: str, storage_path: str, team_id: str, retry_count: int) -> bool:
    """
    Process a single image to generate embeddings

    Downloads the image from the configured bucket, generates its embedding,
    upserts the vector into Qdrant and records the point ID on the image's
    Firestore document.

    The Qdrant point ID is a UUIDv5 derived from ``image_id``, so processing
    the same image again overwrites its existing point instead of adding a
    duplicate vector.

    Args:
        image_id: The ID of the image to process
        storage_path: The GCS path of the image within the bucket
            (e.g., "team_id/filename.jpg")
        team_id: The team ID that owns the image
        retry_count: Current retry count (not used by this function)

    Returns:
        True if the embedding was generated and stored, False otherwise.
        False is also returned when no Qdrant client is available, when the
        image is missing from storage, or when any exception occurs.
    """
    try:
        # Without a vector store the embedding cannot be persisted, so do not
        # download the image or call the model, and do not report success.
        if not qdrant_client:
            logger.error(f"Qdrant client unavailable, cannot store embeddings for image {image_id}")
            return False

        # The bucket name comes from the environment variable; storage_path
        # is only the object path inside that bucket.
        bucket = storage_client.bucket(GCS_BUCKET_NAME)
        blob = bucket.blob(storage_path)

        if not blob.exists():
            logger.error(f"Image not found in storage: {GCS_BUCKET_NAME}/{storage_path}")
            return False

        # Download image data
        image_data = blob.download_as_bytes()

        # Generate embeddings using Vertex AI
        embeddings = generate_image_embeddings(image_data)

        if embeddings is None:
            logger.error(f"Failed to generate embeddings for image {image_id}")
            return False

        # Deterministic ID: the same image always maps to the same point, so
        # the upsert below replaces it on reprocessing.
        point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"image-embedding:{image_id}"))

        # Prepare metadata
        metadata = {
            'image_id': image_id,
            'team_id': team_id,
            'storage_path': storage_path,
            'created_at': datetime.utcnow().isoformat(),
            'model': 'vertex-ai-multimodal',
        }

        # Create point for Qdrant
        point = PointStruct(
            id=point_id,
            vector=embeddings.tolist(),
            payload=metadata,
        )

        # Upsert to Qdrant
        qdrant_client.upsert(
            collection_name=QDRANT_COLLECTION,
            points=[point],
        )

        logger.info(f"Stored embeddings for image {image_id} in Qdrant with point ID {point_id}")

        # Update Firestore with embedding info
        update_image_embedding_info(image_id, point_id, 'vertex-ai-multimodal')

        return True

    except Exception as e:
        logger.exception(f"Error processing image {image_id}: {e}")
        return False
|
|
|
|
# Lazily created model handle, kept at module scope so that later calls in the
# same process reuse it instead of constructing a new one each time.
_embedding_model = None


def _get_embedding_model():
    """Return the shared multimodal embedding model, creating it on first use."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
    return _embedding_model


def generate_image_embeddings(image_data: bytes) -> Optional[np.ndarray]:
    """
    Generate image embeddings using Vertex AI multimodal embedding model

    Args:
        image_data: Binary image data

    Returns:
        float32 numpy array of embeddings, or None if the input is empty,
        cannot be opened as an image, has a zero dimension, or if the model
        returns no, empty or non-finite embeddings. Any unexpected exception
        is logged and also results in None.
    """
    try:
        # Basic validation of image data
        if not image_data:
            logger.error("Empty image data provided")
            return None

        # Payloads above 10MB are only reported, not rejected
        if len(image_data) > 10 * 1024 * 1024:
            logger.warning(f"Large image detected: {len(image_data)} bytes")

        # Validate image format using PIL; the context manager releases the
        # image once its header information has been inspected.
        try:
            with Image.open(io.BytesIO(image_data)) as pil_image:
                logger.info(f"Image format: {pil_image.format}, size: {pil_image.size}, mode: {pil_image.mode}")

                # Check for blank/empty images
                if pil_image.size[0] == 0 or pil_image.size[1] == 0:
                    logger.error("Image has zero dimensions")
                    return None
        except Exception as e:
            logger.error(f"Invalid image format: {e}")
            return None

        # Create Vertex AI image object
        vertex_image = VertexImage(image_data)

        # Use multimodal embedding model to get embeddings
        embeddings = _get_embedding_model().get_embeddings(image=vertex_image)

        if embeddings is None or embeddings.image_embedding is None:
            logger.error("Failed to generate embeddings - no image embedding returned")
            return None

        # Convert to numpy array - DO NOT normalize Vertex AI embeddings
        # This must match the behavior in the main embedding service
        embeddings_array = np.array(embeddings.image_embedding, dtype=np.float32)

        # An empty vector is unusable and would make the ratio below 0/0
        if embeddings_array.size == 0:
            logger.error("Failed to generate embeddings - no image embedding returned")
            return None

        # Validate embedding quality
        if not np.all(np.isfinite(embeddings_array)):
            logger.error("Generated embeddings contain NaN or infinite values")
            return None

        # Check if embedding is mostly zeros (might indicate processing issue)
        zero_ratio = np.count_nonzero(embeddings_array == 0.0) / embeddings_array.size
        if zero_ratio > 0.9:
            logger.warning(f"Embedding is {zero_ratio*100:.1f}% zeros - might indicate processing issue")

        logger.info(f"Generated embeddings with shape: {embeddings_array.shape}")
        logger.info(f"Embedding stats - min: {embeddings_array.min():.6f}, max: {embeddings_array.max():.6f}, norm: {np.linalg.norm(embeddings_array):.6f}")

        return embeddings_array

    except Exception as e:
        logger.error(f"Error generating embeddings: {e}")
        return None
|
|
|
|
def update_image_status(image_id: str, status: str, retry_count: int, error_message: str = None):
    """
    Write the current embedding status onto the image's Firestore document.

    Every call records the status, the retry count and the time of the
    attempt. A 'success' status additionally sets ``has_embedding`` and
    resets ``embedding_error`` to None; for any other status a non-empty
    ``error_message`` is stored as ``embedding_error``. Failures while
    writing are logged and not raised.

    Args:
        image_id: The ID of the image
        status: The new status (processing, success, failed)
        retry_count: Current retry count
        error_message: Error message if status is failed
    """
    try:
        changes = {
            'embedding_status': status,
            'embedding_retry_count': retry_count,
            'embedding_last_attempt': datetime.utcnow()
        }

        if status == 'success':
            # Success always wins: flag the embedding and clear any previous error
            changes.update({'has_embedding': True, 'embedding_error': None})
        elif error_message:
            changes['embedding_error'] = error_message

        firestore_client.collection('images').document(image_id).update(changes)
        logger.info(f"Updated image {image_id} status to {status}")

    except Exception as e:
        logger.error(f"Error updating image status: {e}")
|
|
|
|
def update_image_embedding_info(image_id: str, point_id: str, model: str):
    """
    Record on the image's Firestore document where its embedding is stored.

    Sets the Qdrant point ID, the model name and the ``has_embedding`` flag.
    Failures while writing are logged and not raised.

    Args:
        image_id: The ID of the image
        point_id: The ID of the point in the Qdrant vector database
        model: The model used to generate embeddings
    """
    try:
        image_doc = firestore_client.collection('images').document(image_id)
        image_doc.update({
            'embedding_point_id': point_id,
            'embedding_model': model,
            'has_embedding': True
        })
        logger.info(f"Updated image {image_id} with embedding info")

    except Exception as e:
        logger.error(f"Error updating image embedding info: {e}")