203 lines
7.7 KiB
Python
203 lines
7.7 KiB
Python
import io
|
|
import logging
|
|
import os
|
|
from typing import List, Dict, Any, Union, Optional
|
|
import numpy as np
|
|
from PIL import Image
|
|
import vertexai
|
|
from vertexai.vision_models import MultiModalEmbeddingModel
|
|
|
|
from src.config.config import settings
|
|
|
|
# Module-level logger named after this module; used by EmbeddingService below.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class EmbeddingService:
    """Service for generating image and text embeddings using Vertex AI multimodal model.

    Vertex AI is initialised eagerly when the service is constructed; the
    embedding model itself is loaded lazily on the first embedding request.
    Embedding vectors returned by this service are scaled to unit L2 norm
    (a zero vector is returned unchanged).
    """

    def __init__(self):
        # Loaded on demand by _load_model().
        self.model = None
        self.model_name = "multimodalembedding@001"
        # Vertex AI multimodal embedding dimensions; not read by any method of
        # this class.
        self.embedding_dim = 1408
        # Raises if Vertex AI cannot be initialised (see _initialize_vertex_ai).
        self._initialize_vertex_ai()

    def _initialize_vertex_ai(self):
        """Initialize the Vertex AI SDK.

        The project ID comes from the ``GOOGLE_CLOUD_PROJECT`` environment
        variable, falling back to ``settings.FIRESTORE_PROJECT_ID``. The
        location comes from ``VERTEX_AI_LOCATION`` (default ``us-central1``).

        Raises:
            ValueError: If no project ID is configured.
            Exception: Any error raised by ``vertexai.init`` is logged and re-raised.
        """
        project_id = os.environ.get('GOOGLE_CLOUD_PROJECT') or settings.FIRESTORE_PROJECT_ID
        location = os.environ.get('VERTEX_AI_LOCATION', 'us-central1')

        if not project_id:
            # Checked outside the try block so the problem is logged exactly
            # once, naming the sources that were actually consulted.
            logger.error(
                "Google Cloud project ID not configured: set the GOOGLE_CLOUD_PROJECT "
                "environment variable or settings.FIRESTORE_PROJECT_ID"
            )
            raise ValueError("Google Cloud Project ID not configured")

        try:
            vertexai.init(project=project_id, location=location)
        except Exception as e:
            logger.exception("Error initializing Vertex AI: %s", e)
            raise
        logger.info("Initialized Vertex AI with project %s in location %s", project_id, location)

    def _load_model(self):
        """Load the Vertex AI multimodal embedding model if it is not loaded yet.

        There is no lock, so concurrent first calls (for example from the
        executor threads used by generate_text_embedding) may each load the
        model; the last assignment to ``self.model`` wins.

        Raises:
            Exception: Any error from ``MultiModalEmbeddingModel.from_pretrained``
                is logged and re-raised.
        """
        if self.model is not None:
            return
        try:
            logger.info("Loading Vertex AI multimodal embedding model: %s", self.model_name)
            self.model = MultiModalEmbeddingModel.from_pretrained(self.model_name)
            logger.info("Vertex AI multimodal embedding model loaded successfully")
        except Exception as e:
            logger.exception("Error loading Vertex AI model: %s", e)
            raise

    @staticmethod
    def _to_unit_vector(vector) -> np.ndarray:
        """Return *vector* as a float32 array scaled to unit L2 norm.

        A zero-norm vector is returned unchanged instead of being divided by zero.
        """
        array = np.asarray(vector, dtype=np.float32)
        norm = np.linalg.norm(array)
        if norm > 0:
            array = array / norm
        return array

    def generate_image_embedding(self, image_data: bytes) -> List[float]:
        """
        Generate embedding for an image using Vertex AI multimodal model

        Args:
            image_data: Binary image data

        Returns:
            Image embedding as a list of floats, scaled to unit L2 norm

        Raises:
            ValueError: If ``image_data`` is empty or the model returns no
                image embedding.
            Exception: Errors from the Vertex AI SDK are logged and re-raised.
        """
        try:
            if not image_data:
                # Fail fast instead of sending an empty payload to the API.
                raise ValueError("image_data is empty")

            self._load_model()

            # Imported under an alias to avoid clashing with PIL's Image.
            from vertexai.vision_models import Image as VertexImage

            vertex_image = VertexImage(image_data)
            embeddings = self.model.get_embeddings(image=vertex_image)

            if embeddings is None or embeddings.image_embedding is None:
                logger.error("Failed to generate image embeddings - no image embedding returned")
                raise ValueError("Failed to generate image embeddings")

            embeddings_array = self._to_unit_vector(embeddings.image_embedding)
            logger.info("Generated image embeddings with shape: %s", embeddings_array.shape)
            return embeddings_array.tolist()

        except Exception as e:
            logger.exception("Error generating image embedding: %s", e)
            raise

    async def generate_text_embedding(self, text: str) -> Optional[List[float]]:
        """
        Generate embedding for a text query using Vertex AI multimodal model

        The Vertex AI SDK call (and the first-time model load) is blocking,
        so it runs in the event loop's default executor rather than on the
        event loop thread.

        Args:
            text: Text query

        Returns:
            Text embedding as a list of floats scaled to unit L2 norm, or None
            if the text is empty or embedding generation failed (failures are
            logged, not raised)
        """
        # Local import, matching this file's existing use of in-function imports.
        import asyncio

        if not text:
            logger.warning("Failed to generate text embeddings - empty text query")
            return None

        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._embed_text_sync, text)
        except Exception as e:
            logger.exception("Error generating text embedding: %s", e)
            return None

    def _embed_text_sync(self, text: str) -> Optional[List[float]]:
        """Blocking body of generate_text_embedding; intended to run in an executor thread."""
        self._load_model()

        embeddings = self.model.get_embeddings(contextual_text=text)
        if embeddings is None or embeddings.text_embedding is None:
            logger.error("Failed to generate text embeddings - no text embedding returned")
            return None

        embeddings_array = self._to_unit_vector(embeddings.text_embedding)
        logger.info("Generated text embeddings with shape: %s", embeddings_array.shape)
        return embeddings_array.tolist()

    async def process_image_async(self, image_id: str, storage_path: str) -> bool:
        """
        Process image asynchronously to generate embeddings

        Currently a placeholder: it only logs the request and returns True.
        No message is published and no embedding is generated.

        Args:
            image_id: Image ID
            storage_path: Path to image in storage

        Returns:
            True if processing started successfully (always True at present,
            unless logging itself fails)
        """
        try:
            # TODO: intended pipeline (not implemented here):
            # 1. Publish a message to a Pub/Sub queue
            # 2. A Cloud Function picks up the message
            # 3. Generate embeddings with the Vertex AI multimodal model
            # 4. Store embeddings in Pinecone
            # 5. Update the image record with embedding info
            logger.info("Starting async processing for image %s at %s", image_id, storage_path)
            return True
        except Exception as e:
            logger.exception("Error starting async image processing: %s", e)
            return False

    async def delete_embedding(self, embedding_id: str) -> bool:
        """
        Delete embedding from vector database

        Currently a placeholder: it only logs the request and returns True;
        nothing is deleted.

        Args:
            embedding_id: Embedding ID in vector database

        Returns:
            True if deletion was successful (always True at present, unless
            logging itself fails)
        """
        try:
            # TODO: call the Pinecone API to delete the vector.
            logger.info("Deleting embedding %s from vector database", embedding_id)
            return True
        except Exception as e:
            logger.exception("Error deleting embedding: %s", e)
            return False

    def calculate_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Calculate cosine similarity between two embeddings

        Args:
            embedding1: First embedding
            embedding2: Second embedding

        Returns:
            Cosine similarity in [-1, 1]. Returns 0.0 when either embedding
            has zero norm (including empty embeddings) instead of NaN.

        Raises:
            ValueError: If the embeddings have different shapes.
        """
        try:
            vec1 = np.asarray(embedding1, dtype=np.float64)
            vec2 = np.asarray(embedding2, dtype=np.float64)

            if vec1.shape != vec2.shape:
                raise ValueError(f"Embedding shapes differ: {vec1.shape} vs {vec2.shape}")

            denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
            if denominator == 0:
                # Cosine similarity is undefined for a zero vector; 0.0 avoids
                # propagating NaN into ranking/threshold logic.
                return 0.0

            similarity = np.dot(vec1, vec2) / denominator
            # Rounding can push the ratio marginally outside [-1, 1].
            return float(np.clip(similarity, -1.0, 1.0))
        except Exception as e:
            logger.exception("Error calculating similarity: %s", e)
            raise
|
|
|
|
# Shared module-level instance of the service (presumably imported by other
# modules rather than constructing their own — confirm with callers).
# NOTE: building it runs EmbeddingService.__init__, which initialises Vertex AI,
# so importing this module raises if no Google Cloud project ID is configured.
embedding_service: EmbeddingService = EmbeddingService()