263 lines
9.5 KiB
Python
263 lines
9.5 KiB
Python
import json
|
|
import logging
|
|
import base64
|
|
from datetime import datetime
|
|
from typing import Dict, Any, Optional
|
|
import functions_framework
|
|
from google.cloud import vision
|
|
from google.cloud import firestore
|
|
from google.cloud import storage
|
|
import pinecone
|
|
import numpy as np
|
|
from PIL import Image
|
|
import io
|
|
import os
|
|
|
|
# Logging setup: INFO level via basicConfig, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Google Cloud clients, created once at import time and reused by the
# functions below.
vision_client = vision.ImageAnnotatorClient()
firestore_client = firestore.Client()
storage_client = storage.Client()

# Pinecone settings come from the environment; only the index name has a
# fallback value.
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
PINECONE_ENVIRONMENT = os.getenv('PINECONE_ENVIRONMENT')
PINECONE_INDEX_NAME = os.getenv('PINECONE_INDEX_NAME', 'image-embeddings')

# `index` stays None unless both the API key and the environment are set;
# process_image() checks it before upserting.
index = None
if not (PINECONE_API_KEY and PINECONE_ENVIRONMENT):
    logger.warning("Pinecone not configured, embeddings will not be stored")
else:
    pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENVIRONMENT)
    index = pinecone.Index(PINECONE_INDEX_NAME)
|
|
|
|
@functions_framework.cloud_event
def process_image_embedding(cloud_event):
    """
    Cloud Function triggered by Pub/Sub to process image embeddings.

    The Pub/Sub payload is decoded as base64-encoded JSON. The keys
    ``image_id``, ``storage_path`` and ``team_id`` are required; the message
    is logged and dropped when any of them is missing. ``retry_count`` is
    optional and defaults to 0.

    Args:
        cloud_event: CloudEvent whose ``data["message"]["data"]`` holds the
            base64-encoded JSON task.

    Raises:
        Exception: Any error raised while decoding or handling the message is
            logged with its traceback and re-raised, after a best-effort
            attempt to mark the image as failed.
    """
    # Declared before the try block so the error handler can reuse whatever
    # was parsed before the failure instead of decoding the message again.
    image_id = None
    retry_count = 0

    try:
        # Decode the Pub/Sub message
        message_data = base64.b64decode(cloud_event.data["message"]["data"]).decode('utf-8')
        message = json.loads(message_data)

        logger.info(f"Processing image embedding task: {message}")

        # Extract task data
        image_id = message.get('image_id')
        retry_count = message.get('retry_count', 0)
        storage_path = message.get('storage_path')
        team_id = message.get('team_id')

        if not all([image_id, storage_path, team_id]):
            logger.error(f"Missing required fields in message: {message}")
            return

        # Update image status to processing
        update_image_status(image_id, 'processing', retry_count)

        # Process the image
        success = process_image(image_id, storage_path, team_id, retry_count)

        if success:
            logger.info(f"Successfully processed image {image_id}")
            update_image_status(image_id, 'success', retry_count)
        else:
            logger.error(f"Failed to process image {image_id}")
            update_image_status(image_id, 'failed', retry_count, "Processing failed")
            # Retries are left to the Pub/Sub retry policy.
            # NOTE(review): process_image() reports failure by returning
            # False rather than raising, so this branch ends the invocation
            # normally - confirm whether redelivery is expected for this case.

    except Exception as e:
        # logger.exception keeps the traceback in the logs.
        logger.exception(f"Error in process_image_embedding: {e}")

        # Best-effort status update; it must never mask the original error.
        if image_id:
            try:
                update_image_status(image_id, 'failed', retry_count, str(e))
            except Exception:
                logger.exception(f"Could not mark image {image_id} as failed")

        # Bare raise re-raises the active exception with its traceback intact.
        raise
|
|
|
|
def process_image(image_id: str, storage_path: str, team_id: str, retry_count: int) -> bool:
    """
    Process a single image to generate embeddings.

    Downloads the image from Cloud Storage, builds its embedding vector and,
    when a Pinecone index is configured, upserts the vector and records the
    embedding id on the image's Firestore document.

    Args:
        image_id: The ID of the image to process
        storage_path: The GCS path of the image, either ``bucket/object`` or
            ``gs://bucket/object``
        team_id: The team ID that owns the image
        retry_count: Current retry count (not used by this function)

    Returns:
        True if processing was successful, False otherwise. Errors are logged
        and reported as False; this function does not raise.
    """
    try:
        # Accept both "bucket/object" and "gs://bucket/object".
        gcs_scheme = 'gs://'
        if storage_path.startswith(gcs_scheme):
            relative_path = storage_path[len(gcs_scheme):]
        else:
            relative_path = storage_path
        bucket_name, _, blob_path = relative_path.partition('/')

        if not bucket_name or not blob_path:
            logger.error(f"Invalid storage path for image {image_id}: {storage_path}")
            return False

        # Download image from Cloud Storage
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(blob_path)

        if not blob.exists():
            logger.error(f"Image not found in storage: {storage_path}")
            return False

        # Download image data
        image_data = blob.download_as_bytes()

        # Generate embeddings using Google Cloud Vision
        embeddings = generate_image_embeddings(image_data)

        if embeddings is None:
            logger.error(f"Failed to generate embeddings for image {image_id}")
            return False

        # Store embeddings in Pinecone
        if index:
            embedding_id = f"{team_id}_{image_id}"

            # Prepare metadata. created_at is a naive UTC timestamp in
            # ISO-8601 form (no offset suffix); kept as-is so the stored
            # string format does not change.
            metadata = {
                'image_id': image_id,
                'team_id': team_id,
                'storage_path': storage_path,
                'created_at': datetime.utcnow().isoformat()
            }

            # Upsert to Pinecone
            index.upsert(vectors=[(embedding_id, embeddings.tolist(), metadata)])

            logger.info(f"Stored embeddings for image {image_id} in Pinecone")

            # Update Firestore with embedding info
            update_image_embedding_info(image_id, embedding_id, 'google-vision-v1')
        else:
            # Nothing is persisted without an index; still reported as
            # success to preserve the existing contract, but make it visible.
            logger.warning(
                f"Pinecone not configured, embeddings for image {image_id} were not stored"
            )

        return True

    except Exception as e:
        # logger.exception keeps the traceback in the logs.
        logger.exception(f"Error processing image {image_id}: {e}")
        return False
|
|
|
|
def generate_image_embeddings(image_data: bytes) -> Optional[np.ndarray]:
    """
    Generate image embeddings using Google Cloud Vision API.

    The vector is assembled from two Vision calls: six values per localized
    object (name hash, score, box x, box y, box width, box height) followed
    by two values per label (description hash, score) for the first 10
    labels. The result is zero-padded or truncated to 512 values.

    Args:
        image_data: Binary image data

    Returns:
        float32 numpy array of length 512, or None if either Vision call
        reports an error or an exception occurs.
    """
    target_size = 512

    try:
        # Create Vision API image object
        image = vision.Image(content=image_data)

        # Object localization supplies the per-object features
        response = vision_client.object_localization(image=image)

        if response.error.message:
            logger.error(f"Vision API error: {response.error.message}")
            return None

        features = []

        for obj in response.localized_object_annotations:
            vertices = obj.bounding_poly.normalized_vertices
            if len(vertices) >= 3:
                # Vertex 0 and vertex 2 are used as opposite corners.
                first_corner, opposite_corner = vertices[0], vertices[2]
                box = [
                    first_corner.x,
                    first_corner.y,
                    opposite_corner.x - first_corner.x,  # Width
                    opposite_corner.y - first_corner.y,  # Height
                ]
            else:
                # Keep the six-values-per-object layout rather than failing
                # the whole image on one incomplete polygon.
                box = [0.0, 0.0, 0.0, 0.0]

            features.extend([_stable_unit_hash(obj.name), obj.score, *box])

        # Also get label detection for additional semantic information
        label_response = vision_client.label_detection(image=image)

        # Checked the same way as the object localization response, so a
        # failed call cannot silently yield a partial vector.
        if label_response.error.message:
            logger.error(f"Vision API error: {label_response.error.message}")
            return None

        for label in label_response.label_annotations[:10]:  # Top 10 labels
            features.extend([_stable_unit_hash(label.description), label.score])

        # Pad or truncate to fixed size (512 dimensions)
        if len(features) < target_size:
            features.extend([0.0] * (target_size - len(features)))
        else:
            features = features[:target_size]

        return np.array(features, dtype=np.float32)

    except Exception as e:
        # logger.exception keeps the traceback in the logs.
        logger.exception(f"Error generating embeddings: {e}")
        return None


def _stable_unit_hash(text: str) -> float:
    """Map *text* to one of 1000 evenly spaced values in [0, 1).

    The built-in hash() of a str is salted per interpreter process, so it
    would give the same label different values in different function
    instances. CRC32 of the UTF-8 bytes is the same in every process.
    """
    import zlib  # function-scope import: zlib is not imported at module level

    return zlib.crc32(text.encode('utf-8')) % 1000 / 1000.0
|
|
|
|
def update_image_status(
    image_id: str,
    status: str,
    retry_count: int,
    error_message: Optional[str] = None,
) -> None:
    """
    Update the image embedding status in Firestore.

    Writes ``embedding_status``, ``embedding_retry_count`` and
    ``embedding_last_attempt`` on the ``images/{image_id}`` document. A
    'success' status additionally sets ``has_embedding`` to True and clears
    ``embedding_error``. Failures are logged and swallowed, so this function
    does not raise.

    Args:
        image_id: The ID of the image
        status: The new status (processing, success, failed)
        retry_count: Current retry count
        error_message: Error message if status is failed
    """
    # Function-scope import: the module only imports `datetime` from datetime.
    from datetime import timezone

    try:
        doc_ref = firestore_client.collection('images').document(image_id)

        update_data = {
            'embedding_status': status,
            'embedding_retry_count': retry_count,
            # Timezone-aware UTC instead of the deprecated naive utcnow().
            'embedding_last_attempt': datetime.now(timezone.utc),
        }

        if error_message:
            update_data['embedding_error'] = error_message

        if status == 'success':
            update_data['has_embedding'] = True
            update_data['embedding_error'] = None  # Clear any previous error

        doc_ref.update(update_data)
        logger.info(f"Updated image {image_id} status to {status}")

    except Exception as e:
        # Status updates are best-effort; keep the traceback in the logs.
        logger.exception(f"Error updating image status: {e}")
|
|
|
|
def update_image_embedding_info(image_id: str, embedding_id: str, model: str):
    """
    Record on the image's Firestore document which embedding belongs to it.

    Sets ``embedding_id``, ``embedding_model`` and ``has_embedding`` on the
    ``images/{image_id}`` document. Any failure is logged and swallowed.

    Args:
        image_id: The ID of the image
        embedding_id: The ID of the embedding in the vector database
        model: The model used to generate embeddings
    """
    embedding_fields = dict(
        embedding_id=embedding_id,
        embedding_model=model,
        has_embedding=True,
    )

    try:
        firestore_client.collection('images').document(image_id).update(embedding_fields)
        logger.info(f"Updated image {image_id} with embedding info")
    except Exception as exc:
        logger.error(f"Error updating image embedding info: {exc}")