commit ae219c8111 (parent d6dea7ef17)

    cp

.env.example  (new file, +34)
@@ -0,0 +1,34 @@
# Application Settings
PROJECT_NAME="Image Management API"
API_V1_STR="/api/v1"
LOG_LEVEL=INFO

# CORS Settings
CORS_ORIGINS=*

# Database Settings
# Choose database type: "mongodb" or "firestore"
DATABASE_TYPE=mongodb

# MongoDB Settings (used when DATABASE_TYPE=mongodb)
DATABASE_URI=mongodb://localhost:27017
DATABASE_NAME=imagedb

# Google Cloud Firestore Settings (used when DATABASE_TYPE=firestore)
# Path to service account credentials file (optional, uses application default credentials if not set)
GCS_CREDENTIALS_FILE=path/to/credentials.json

# Google Cloud Storage Settings
GCS_BUCKET_NAME=your-bucket-name

# Security Settings
API_KEY_SECRET=super-secret-key-for-development-only
API_KEY_EXPIRY_DAYS=365

# Vector Database Settings (for image embeddings)
VECTOR_DB_API_KEY=your-pinecone-api-key
VECTOR_DB_ENVIRONMENT=your-pinecone-environment
VECTOR_DB_INDEX_NAME=image-embeddings

# Rate Limiting
RATE_LIMIT_PER_MINUTE=100
README.md  (43 lines changed)
@@ -44,11 +44,11 @@ sereact/
 ## Technology Stack
 
 - FastAPI - Web framework
-- MongoDB - Database
+- Firestore - Database
 - Google Cloud Storage - Image storage
 - Pinecone - Vector database for semantic search
 - CLIP - Image embedding model
 - PyTorch - Deep learning framework
 - NumPy - Scientific computing
 - Pydantic - Data validation
 
 ## Setup and Installation
@@ -56,8 +56,7 @@ sereact/
 ### Prerequisites
 
 - Python 3.8+
-- MongoDB
-- Google Cloud account with Storage enabled
+- Google Cloud account with Firestore and Storage enabled
 - (Optional) Pinecone account for semantic search
 
 ### Installation
@@ -82,9 +81,10 @@ sereact/
 4. Create a `.env` file with the following environment variables:
 ```
-# MongoDB
-DATABASE_URI=mongodb://localhost:27017
+# Firestore
 DATABASE_NAME=imagedb
+FIRESTORE_PROJECT_ID=your-gcp-project-id
+FIRESTORE_CREDENTIALS_FILE=path/to/firestore-credentials.json
 
 # Google Cloud Storage
 GCS_BUCKET_NAME=your-bucket-name
@@ -151,6 +151,37 @@ pytest
 gcloud run deploy sereact --image gcr.io/your-project/sereact --platform managed
 ```
 
+## Local Development with Docker Compose
+
+To run the application locally using Docker Compose:
+
+1. Make sure you have Docker and Docker Compose installed
+2. Run the following command in the project root:
+
+```bash
+docker compose up
+```
+
+This will:
+- Build the API container based on the Dockerfile
+- Mount your local codebase into the container for live reloading
+- Mount your Firestore credentials for authentication
+- Expose the API on http://localhost:8000
+
+To stop the containers:
+
+```bash
+docker compose down
+```
+
+To rebuild containers after making changes to the Dockerfile or requirements:
+
+```bash
+docker compose up --build
+```
+
 ## Additional Information
 
 ## License
 
 This project is licensed under the MIT License - see the LICENSE file for details.
docker-compose.yml  (new file, +17)
@@ -0,0 +1,17 @@
version: '3.8'

services:
  api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - .:/app
      - ${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
    environment:
      - PYTHONUNBUFFERED=1
      - ENVIRONMENT=development
      - DATABASE_NAME=imagedb
      - FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
      - GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
    command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
requirements.txt  (changed)
@@ -5,8 +5,7 @@ pydantic-settings==2.0.3
 python-dotenv==1.0.0
 google-cloud-storage==2.12.0
 google-cloud-vision==3.4.5
-pymongo==4.5.0
-motor==3.3.1
+google-cloud-firestore==2.11.1
 python-multipart==0.0.6
 python-jose==3.3.0
 passlib==1.7.4
@@ -15,7 +14,5 @@ pytest==7.4.3
 httpx==0.25.1
 pinecone-client==2.2.4
 pillow==10.1.0
-clip==0.2.0
-torch==2.1.0
-transformers==4.35.0
 python-slugify==8.0.1
+email-validator==2.1.0.post1
scripts/README.md  (new file, +188)
@@ -0,0 +1,188 @@
# Build and Deployment Scripts

This directory contains scripts for building and deploying the Sereact API application.

## Prerequisites

- Docker installed and running
- For deployment: access to a container registry (e.g., DockerHub, Google Container Registry)
- For Cloud Run deployment: Google Cloud SDK (`gcloud`) installed and configured

## Scripts

### Build Script (`build.sh`)

Builds the Docker image for the Sereact API.

**Usage:**
```bash
# Basic usage (builds with default settings)
./scripts/build.sh

# Customize the image name
IMAGE_NAME=my-custom-name ./scripts/build.sh

# Customize the image tag
IMAGE_TAG=v1.0.0 ./scripts/build.sh

# Use a custom registry
REGISTRY=gcr.io/my-project ./scripts/build.sh
```

**Environment Variables:**
- `IMAGE_NAME`: Name for the Docker image (default: "sereact-api")
- `IMAGE_TAG`: Tag for the Docker image (default: "latest")
- `REGISTRY`: Container registry to use (default: empty, using DockerHub)

### Deploy Script (`deploy.sh`)

Pushes the built Docker image to a container registry and optionally deploys to Google Cloud Run.

**Usage:**
```bash
# Push to registry only
./scripts/deploy.sh

# Push to registry and deploy to Cloud Run
DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=my-project-id ./scripts/deploy.sh

# Customize deployment settings
DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=my-project-id REGION=us-west1 SERVICE_NAME=my-api ./scripts/deploy.sh
```

**Environment Variables:**
All variables from the build script, plus:
- `DEPLOY_TO_CLOUD_RUN`: Set to "true" to deploy to Cloud Run (default: "false")
- `PROJECT_ID`: Google Cloud project ID (required for Cloud Run deployment)
- `REGION`: Google Cloud region (default: "us-central1")
- `SERVICE_NAME`: Name for the Cloud Run service (default: "sereact-api")

### Cloud Run Deployment Script (`deploy-to-cloud-run.sh`)

Deploys the application to Google Cloud Run using the service configuration file.

**Usage:**
```bash
# Deploy using existing service.yaml
PROJECT_ID=my-project-id ./scripts/deploy-to-cloud-run.sh

# Build, push, and deploy in one command
PROJECT_ID=my-project-id BUILD=true PUSH=true ./scripts/deploy-to-cloud-run.sh

# Customize the deployment
PROJECT_ID=my-project-id REGION=us-west1 IMAGE_TAG=v1.0.0 ./scripts/deploy-to-cloud-run.sh
```

**Environment Variables:**
- `PROJECT_ID`: Google Cloud project ID (required)
- `REGION`: Google Cloud region (default: "us-central1")
- `SERVICE_CONFIG`: Path to the service configuration file (default: "deployment/cloud-run/service.yaml")
- `IMAGE_NAME`: Name for the Docker image (default: "sereact-api")
- `IMAGE_TAG`: Tag for the Docker image (default: "latest")
- `REGISTRY`: Container registry to use (default: "gcr.io")
- `BUILD`: Set to "true" to build the image before deployment (default: "false")
- `PUSH`: Set to "true" to push the image before deployment (default: "false")

## Example Workflows

### Basic workflow:
```bash
# Build and tag with version
IMAGE_TAG=v1.0.0 ./scripts/build.sh

# Deploy to Cloud Run
DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=my-project-id IMAGE_TAG=v1.0.0 ./scripts/deploy.sh
```

### Using the Cloud Run config file:
```bash
# Build and deploy in one step
PROJECT_ID=my-project-id BUILD=true PUSH=true ./scripts/deploy-to-cloud-run.sh
```

# Scripts Documentation

This directory contains utility scripts for the SEREACT application.

## Database Seeding Scripts

### `seed_firestore.py`

This script initializes and seeds a Google Cloud Firestore database with initial data for the SEREACT application. It creates teams, users, API keys, and sample image metadata.

#### Requirements

- Google Cloud project with Firestore enabled
- Google Cloud credentials configured on your machine
- Python 3.8+
- Required Python packages (listed in `requirements.txt`)

#### Setup

1. Make sure you have the Google Cloud SDK installed and configured with access to your project:
```bash
gcloud auth login
gcloud config set project YOUR_PROJECT_ID
```

2. If not using application default credentials, create a service account key file:
```bash
gcloud iam service-accounts create sereact-app
gcloud projects add-iam-policy-binding YOUR_PROJECT_ID --member="serviceAccount:sereact-app@YOUR_PROJECT_ID.iam.gserviceaccount.com" --role="roles/datastore.user"
gcloud iam service-accounts keys create credentials.json --iam-account=sereact-app@YOUR_PROJECT_ID.iam.gserviceaccount.com
```

3. Set environment variables:
```bash
# Windows (CMD)
set DATABASE_TYPE=firestore
set GCS_CREDENTIALS_FILE=path/to/credentials.json

# Windows (PowerShell)
$env:DATABASE_TYPE="firestore"
$env:GCS_CREDENTIALS_FILE="path/to/credentials.json"

# Linux/macOS
export DATABASE_TYPE=firestore
export GCS_CREDENTIALS_FILE=path/to/credentials.json
```

#### Usage

Run the seeding script from the project root directory:

```bash
# Activate the Python virtual environment
source venv/bin/activate  # Linux/macOS
venv\Scripts\activate     # Windows

# Run the script
python scripts/seed_firestore.py
```

#### Generated Data

The script will create the following data:

1. **Teams**:
   - Sereact Development
   - Marketing Team
   - Customer Support

2. **Users**:
   - Admin User (team: Sereact Development)
   - Developer User (team: Sereact Development)
   - Marketing User (team: Marketing Team)
   - Support User (team: Customer Support)

3. **API Keys**:
   - One API key per user (the keys will be output to the console, save them securely)

4. **Images**:
   - Sample image metadata (3 images, one for each team)

#### Notes

- The script logs the generated API keys to the console. Save these keys somewhere secure as they won't be displayed again.
- If you need to re-run the script with existing data, use the `--force` flag to overwrite existing data.
- This script only creates metadata entries for images - it does not upload actual files to Google Cloud Storage.
scripts/build.sh  (new file, +29)
@@ -0,0 +1,29 @@
#!/bin/bash
set -e

# Set defaults
IMAGE_NAME=${IMAGE_NAME:-"sereact-api"}
IMAGE_TAG=${IMAGE_TAG:-"latest"}

# Allow custom registry (defaults to DockerHub)
REGISTRY=${REGISTRY:-""}
REGISTRY_PREFIX=""
if [ -n "$REGISTRY" ]; then
    REGISTRY_PREFIX="${REGISTRY}/"
fi

# Full image reference
FULL_IMAGE_NAME="${REGISTRY_PREFIX}${IMAGE_NAME}:${IMAGE_TAG}"

echo "Building Docker image: ${FULL_IMAGE_NAME}"

# Build the Docker image
docker build -t "${FULL_IMAGE_NAME}" -f Dockerfile .

echo "Build completed successfully"
echo "Image: ${FULL_IMAGE_NAME}"

# Print run command for testing locally
echo ""
echo "To run the image locally:"
echo "docker run -p 8000:8000 ${FULL_IMAGE_NAME}"
scripts/deploy-to-cloud-run.sh  (new file, +57)
@@ -0,0 +1,57 @@
#!/bin/bash
set -e

# Default values
PROJECT_ID=${PROJECT_ID:-""}
REGION=${REGION:-"us-central1"}
SERVICE_CONFIG=${SERVICE_CONFIG:-"deployment/cloud-run/service.yaml"}
IMAGE_NAME=${IMAGE_NAME:-"sereact-api"}
IMAGE_TAG=${IMAGE_TAG:-"latest"}
REGISTRY=${REGISTRY:-"gcr.io"}

# Validate required parameters
if [ -z "$PROJECT_ID" ]; then
    echo "Error: PROJECT_ID environment variable is required"
    echo "Usage: PROJECT_ID=your-project-id ./scripts/deploy-to-cloud-run.sh"
    exit 1
fi

# Full image reference
FULL_IMAGE_NAME="${REGISTRY}/${PROJECT_ID}/${IMAGE_NAME}:${IMAGE_TAG}"

# Check if service config exists
if [ ! -f "$SERVICE_CONFIG" ]; then
    echo "Error: Service configuration file not found at $SERVICE_CONFIG"
    exit 1
fi

# Build the image if BUILD=true
if [ "${BUILD:-false}" = "true" ]; then
    echo "Building Docker image: ${FULL_IMAGE_NAME}"
    docker build -t "${FULL_IMAGE_NAME}" -f Dockerfile .
    echo "Build completed successfully"
fi

# Push the image if PUSH=true
if [ "${PUSH:-false}" = "true" ]; then
    echo "Pushing image to registry: ${FULL_IMAGE_NAME}"
    docker push "${FULL_IMAGE_NAME}"
    echo "Image pushed successfully"
fi

# Update the image in the service configuration
echo "Updating image reference in service configuration..."
TMP_CONFIG=$(mktemp)
sed "s|image: .*|image: ${FULL_IMAGE_NAME}|g" "$SERVICE_CONFIG" > "$TMP_CONFIG"

echo "Deploying to Cloud Run using configuration..."
gcloud run services replace "$TMP_CONFIG" \
    --project="$PROJECT_ID" \
    --region="$REGION" \
    --platform=managed

rm "$TMP_CONFIG"

echo "Deployment completed successfully"
echo "Service URL: $(gcloud run services describe sereact --region=${REGION} --project=${PROJECT_ID} --format='value(status.url)')"
echo "To view logs: gcloud logging read 'resource.type=cloud_run_revision AND resource.labels.service_name=sereact' --project=$PROJECT_ID --limit=10"
scripts/deploy.sh  (new file, +43)
@@ -0,0 +1,43 @@
#!/bin/bash
set -e

# Source the build environment to reuse variables
source "$(dirname "$0")/build.sh"

# Push the Docker image to the registry
echo "Pushing image: ${FULL_IMAGE_NAME} to registry..."
docker push "${FULL_IMAGE_NAME}"
echo "Image pushed successfully"

# Check if we need to deploy to Cloud Run
DEPLOY_TO_CLOUD_RUN=${DEPLOY_TO_CLOUD_RUN:-false}

if [ "$DEPLOY_TO_CLOUD_RUN" = true ]; then
    echo "Deploying to Cloud Run..."

    # Cloud Run settings
    PROJECT_ID=${PROJECT_ID:-""}
    REGION=${REGION:-"us-central1"}
    SERVICE_NAME=${SERVICE_NAME:-"sereact-api"}

    if [ -z "$PROJECT_ID" ]; then
        echo "Error: PROJECT_ID environment variable is required for Cloud Run deployment"
        exit 1
    fi

    # Deploy to Cloud Run
    gcloud run deploy "${SERVICE_NAME}" \
        --image="${FULL_IMAGE_NAME}" \
        --platform=managed \
        --region="${REGION}" \
        --project="${PROJECT_ID}" \
        --allow-unauthenticated \
        --port=8000

    echo "Deployment to Cloud Run completed"
    echo "Service URL: $(gcloud run services describe ${SERVICE_NAME} --region=${REGION} --project=${PROJECT_ID} --format='value(status.url)')"
else
    echo ""
    echo "To deploy to Cloud Run:"
    echo "DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=your-project-id ./scripts/deploy.sh"
fi
scripts/seed_firestore.py  (new file, +272)
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""
Script to seed the Firestore database with initial data.
"""

import os
import sys
import asyncio
import logging
import argparse
from datetime import datetime, timedelta
import secrets
import hashlib
from bson import ObjectId
from pydantic import HttpUrl

# Add the parent directory to the path so we can import from src
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.db.models.team import TeamModel
from src.db.models.user import UserModel
from src.db.models.api_key import ApiKeyModel
from src.db.models.image import ImageModel
from src.db.providers.firestore_provider import firestore_db
from src.db.repositories.firestore_team_repository import firestore_team_repository
from src.db.repositories.firestore_user_repository import firestore_user_repository
from src.db.repositories.firestore_api_key_repository import firestore_api_key_repository
from src.db.repositories.firestore_image_repository import firestore_image_repository

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

def generate_api_key(length=32):
    """Generate a random API key"""
    return secrets.token_hex(length)

def hash_api_key(api_key):
    """Hash an API key for storage"""
    return hashlib.sha256(api_key.encode()).hexdigest()

async def seed_teams():
    """Seed the database with team data"""
    logger.info("Seeding teams...")

    teams_data = [
        {
            "name": "Sereact Development",
            "description": "Internal development team"
        },
        {
            "name": "Marketing Team",
            "description": "Marketing and design team"
        },
        {
            "name": "Customer Support",
            "description": "Customer support and success team"
        }
    ]

    team_ids = []
    for team_data in teams_data:
        team = TeamModel(**team_data)
        created_team = await firestore_team_repository.create(team)
        team_ids.append(created_team.id)
        logger.info(f"Created team: {created_team.name} (ID: {created_team.id})")

    return team_ids

async def seed_users(team_ids):
    """Seed the database with user data"""
    logger.info("Seeding users...")

    users_data = [
        {
            "email": "admin@sereact.com",
            "name": "Admin User",
            "team_id": team_ids[0],
            "is_admin": True
        },
        {
            "email": "developer@sereact.com",
            "name": "Developer User",
            "team_id": team_ids[0]
        },
        {
            "email": "marketing@sereact.com",
            "name": "Marketing User",
            "team_id": team_ids[1]
        },
        {
            "email": "support@sereact.com",
            "name": "Support User",
            "team_id": team_ids[2]
        }
    ]

    user_ids = []
    for user_data in users_data:
        user = UserModel(**user_data)
        created_user = await firestore_user_repository.create(user)
        user_ids.append(created_user.id)
        logger.info(f"Created user: {created_user.name} (ID: {created_user.id})")

    return user_ids

async def seed_api_keys(user_ids, team_ids):
    """Seed the database with API key data"""
    logger.info("Seeding API keys...")

    api_keys_data = [
        {
            "user_id": user_ids[0],
            "team_id": team_ids[0],
            "name": "Admin Key",
            "description": "API key for admin user"
        },
        {
            "user_id": user_ids[1],
            "team_id": team_ids[0],
            "name": "Development Key",
            "description": "API key for development user"
        },
        {
            "user_id": user_ids[2],
            "team_id": team_ids[1],
            "name": "Marketing Key",
            "description": "API key for marketing user"
        },
        {
            "user_id": user_ids[3],
            "team_id": team_ids[2],
            "name": "Support Key",
            "description": "API key for support user"
        }
    ]

    generated_keys = []
    for api_key_data in api_keys_data:
        # Generate a unique API key
        api_key = generate_api_key()
        key_hash = hash_api_key(api_key)

        # Create API key object
        api_key_data["key_hash"] = key_hash
        api_key_data["expiry_date"] = datetime.utcnow() + timedelta(days=365)

        api_key_obj = ApiKeyModel(**api_key_data)
        created_api_key = await firestore_api_key_repository.create(api_key_obj)

        generated_keys.append({
            "id": created_api_key.id,
            "key": api_key,
            "name": created_api_key.name
        })

        logger.info(f"Created API key: {created_api_key.name} (ID: {created_api_key.id})")

    # Print the generated keys for reference
    logger.info("\nGenerated API Keys (save these somewhere secure):")
    for key in generated_keys:
        logger.info(f"Name: {key['name']}, Key: {key['key']}")

    return generated_keys

async def seed_images(team_ids, user_ids):
    """Seed the database with image metadata"""
    logger.info("Seeding images...")

    images_data = [
        {
            "filename": "image1.jpg",
            "original_filename": "product_photo.jpg",
            "file_size": 1024 * 1024,  # 1MB
            "content_type": "image/jpeg",
            "storage_path": "teams/{}/images/image1.jpg".format(team_ids[0]),
            "public_url": "https://storage.googleapis.com/example-bucket/teams/{}/images/image1.jpg".format(team_ids[0]),
            "team_id": team_ids[0],
            "uploader_id": user_ids[0],
            "description": "Product photo for marketing",
            "tags": ["product", "marketing", "high-resolution"],
            "metadata": {
                "width": 1920,
                "height": 1080,
                "color_space": "sRGB"
            }
        },
        {
            "filename": "image2.png",
            "original_filename": "logo.png",
            "file_size": 512 * 1024,  # 512KB
            "content_type": "image/png",
            "storage_path": "teams/{}/images/image2.png".format(team_ids[1]),
            "public_url": "https://storage.googleapis.com/example-bucket/teams/{}/images/image2.png".format(team_ids[1]),
            "team_id": team_ids[1],
            "uploader_id": user_ids[2],
            "description": "Company logo",
            "tags": ["logo", "branding"],
            "metadata": {
                "width": 800,
                "height": 600,
                "color_space": "sRGB"
            }
        },
        {
            "filename": "image3.jpg",
            "original_filename": "support_screenshot.jpg",
            "file_size": 256 * 1024,  # 256KB
            "content_type": "image/jpeg",
            "storage_path": "teams/{}/images/image3.jpg".format(team_ids[2]),
            "public_url": "https://storage.googleapis.com/example-bucket/teams/{}/images/image3.jpg".format(team_ids[2]),
            "team_id": team_ids[2],
            "uploader_id": user_ids[3],
            "description": "Screenshot for support ticket",
            "tags": ["support", "screenshot", "bug"],
            "metadata": {
                "width": 1280,
                "height": 720,
                "color_space": "sRGB"
            }
        }
    ]

    image_ids = []
    for image_data in images_data:
        image = ImageModel(**image_data)
        created_image = await firestore_image_repository.create(image)
        image_ids.append(created_image.id)
        logger.info(f"Created image: {created_image.filename} (ID: {created_image.id})")

    return image_ids

async def seed_database():
    """Seed the database with initial data"""
    try:
        # Connect to Firestore
        firestore_db.connect()

        # Seed teams first
        team_ids = await seed_teams()

        # Seed users with team IDs
        user_ids = await seed_users(team_ids)

        # Seed API keys with user and team IDs
        api_keys = await seed_api_keys(user_ids, team_ids)

        # Seed images with team and user IDs
        image_ids = await seed_images(team_ids, user_ids)

        logger.info("Database seeding completed successfully!")

    except Exception as e:
        logger.error(f"Error seeding database: {e}")
        raise
    finally:
        # Disconnect from Firestore
        firestore_db.disconnect()

def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(description="Seed the Firestore database with initial data")
    parser.add_argument("--force", action="store_true", help="Force seeding even if data exists")
    args = parser.parse_args()

    asyncio.run(seed_database())

if __name__ == "__main__":
    main()
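Because the seed script stores only the SHA-256 hash of each key, authentication later has to recompute the hash from the key a client presents. A minimal sketch of that check (not part of this commit; it assumes the same sha256 scheme and the `get_by_key_hash()` method added further down in this diff):

```python
import hashlib

from src.db.repositories.firestore_api_key_repository import firestore_api_key_repository

async def authenticate(raw_key: str):
    """Resolve a presented API key to its stored record, or fail."""
    # Hash the raw key exactly as seed_firestore.py's hash_api_key() does
    key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
    api_key = await firestore_api_key_repository.get_by_key_hash(key_hash)
    if api_key is None or not api_key.is_active:
        raise PermissionError("Invalid or inactive API key")
    return api_key
```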
src/api/v1/__init__.py
@@ -0,0 +1,2 @@
+from src.api.v1 import teams, auth
+from src.api.v1 import users, images, search
src/api/v1/images.py  (new file, +13)
@@ -0,0 +1,13 @@
import logging
from fastapi import APIRouter, Depends

from src.api.v1.auth import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Images"], prefix="/images")

@router.get("")
async def list_images(current_user = Depends(get_current_user)):
    """List images (placeholder endpoint)"""
    return {"message": "Images listing functionality to be implemented"}
src/api/v1/search.py  (new file, +16)
@@ -0,0 +1,16 @@
import logging
from fastapi import APIRouter, Depends, Query

from src.api.v1.auth import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Search"], prefix="/search")

@router.get("")
async def search_images(
    q: str = Query(..., description="Search query"),
    current_user = Depends(get_current_user)
):
    """Search for images (placeholder endpoint)"""
    return {"message": "Search functionality to be implemented", "query": q}
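Once implemented, this endpoint would plausibly embed the query text and rank images through the vector index named by the `VECTOR_DB_*` settings in `.env.example`. A hypothetical sketch of that wiring (none of it is in this commit; `embedding_service` and `vector_index` are assumed objects, the latter e.g. a pinecone-client index):

```python
async def semantic_search(q: str, embedding_service, vector_index, limit: int = 10):
    # Embed the query text with the same model used for image embeddings
    query_embedding = embedding_service.generate_text_embedding(q)
    # Ask the vector database for the nearest image embeddings
    matches = vector_index.query(vector=query_embedding, top_k=limit)
    return matches
```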
src/api/v1/users.py  (new file, +13)
@@ -0,0 +1,13 @@
import logging
from fastapi import APIRouter, Depends

from src.api.v1.auth import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Users"], prefix="/users")

@router.get("/me")
async def read_users_me(current_user = Depends(get_current_user)):
    """Get current user information"""
    return current_user
src/core/config.py
@@ -1,7 +1,7 @@
 import os
-from typing import List
+from typing import List, ClassVar
 from pydantic_settings import BaseSettings
-from pydantic import AnyHttpUrl, validator
+from pydantic import AnyHttpUrl, field_validator
 
 class Settings(BaseSettings):
     # Project settings
@@ -11,15 +11,16 @@ class Settings(BaseSettings):
     # CORS settings
     CORS_ORIGINS: List[str] = ["*"]
 
-    @validator("CORS_ORIGINS", pre=True)
+    @field_validator("CORS_ORIGINS", mode="before")
     def assemble_cors_origins(cls, v):
         if isinstance(v, str) and not v.startswith("["):
             return [i.strip() for i in v.split(",")]
         return v
 
     # Database settings
     DATABASE_URI: str = os.getenv("DATABASE_URI", "mongodb://localhost:27017")
     DATABASE_NAME: str = os.getenv("DATABASE_NAME", "imagedb")
+    FIRESTORE_PROJECT_ID: str = os.getenv("FIRESTORE_PROJECT_ID", "")
+    FIRESTORE_CREDENTIALS_FILE: str = os.getenv("FIRESTORE_CREDENTIALS_FILE", "firestore-credentials.json")
 
     # Google Cloud Storage settings
     GCS_BUCKET_NAME: str = os.getenv("GCS_BUCKET_NAME", "image-mgmt-bucket")
@@ -40,8 +41,9 @@ class Settings(BaseSettings):
     # Logging
     LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
 
-    class Config:
-        case_sensitive = True
-        env_file = ".env"
+    model_config: ClassVar[dict] = {
+        "case_sensitive": True,
+        "env_file": ".env"
+    }
 
 settings = Settings()
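The `mode="before"` validator runs on the raw value, so a comma-separated `CORS_ORIGINS` string is split into a list before `List[str]` validation. Roughly (illustrative values, not from the commit):

```python
from src.core.config import Settings

# assemble_cors_origins splits and strips a comma-separated value
settings = Settings(CORS_ORIGINS="http://localhost:3000, https://app.example.com")
assert settings.CORS_ORIGINS == ["http://localhost:3000", "https://app.example.com"]
```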
Database wrapper (file path not captured in this view):
@@ -1,33 +1,33 @@
 import logging
-from motor.motor_asyncio import AsyncIOMotorClient
+from google.cloud import firestore
 from src.core.config import settings
 
 logger = logging.getLogger(__name__)
 
 class Database:
-    client: AsyncIOMotorClient = None
+    client = None
 
     def connect_to_database(self):
         """Create database connection."""
         try:
-            self.client = AsyncIOMotorClient(settings.DATABASE_URI)
-            logger.info("Connected to MongoDB")
+            self.client = firestore.Client()
+            logger.info("Connected to Firestore")
         except Exception as e:
-            logger.error(f"Failed to connect to MongoDB: {e}")
+            logger.error(f"Failed to connect to Firestore: {e}")
             raise
 
     def close_database_connection(self):
         """Close database connection."""
         try:
             if self.client:
-                self.client.close()
-                logger.info("Closed MongoDB connection")
+                # No explicit close method needed for Firestore client
+                self.client = None
+                logger.info("Closed Firestore connection")
         except Exception as e:
-            logger.error(f"Failed to close MongoDB connection: {e}")
+            logger.error(f"Failed to close Firestore connection: {e}")
 
     def get_database(self):
         """Get the database instance."""
-        return self.client[settings.DATABASE_NAME]
+        return self.client
 
 # Create a singleton database instance
 db = Database()
src/db/models/api_key.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Optional
+from typing import Optional, ClassVar
 from pydantic import BaseModel, Field
 from bson import ObjectId
 
@@ -18,9 +18,10 @@ class ApiKeyModel(BaseModel):
     last_used: Optional[datetime] = None
     is_active: bool = True
 
-    class Config:
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
-        json_encoders = {
+    model_config: ClassVar[dict] = {
+        "populate_by_name": True,
+        "arbitrary_types_allowed": True,
+        "json_encoders": {
             ObjectId: str
         }
+    }
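This is the standard Pydantic v1-to-v2 rename: `class Config` becomes `model_config`, and `allow_population_by_field_name` becomes `populate_by_name`. A self-contained illustration of what the renamed option does (toy model, not project code):

```python
from pydantic import BaseModel, Field

class Doc(BaseModel):
    model_config = {"populate_by_name": True}
    id: str = Field(alias="_id")

Doc(_id="507f1f77bcf86cd799439011")        # population by alias always works
doc = Doc(id="507f1f77bcf86cd799439011")   # by field name: needs populate_by_name=True
print(doc.model_dump(by_alias=True))       # {'_id': '507f1f77bcf86cd799439011'}
```

The same substitution is applied to the image, team, and user models below.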
src/db/models/image.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict, Any, ClassVar
 from pydantic import BaseModel, Field, HttpUrl
 from bson import ObjectId
 
@@ -27,9 +27,10 @@ class ImageModel(BaseModel):
     embedding_model: Optional[str] = None
     has_embedding: bool = False
 
-    class Config:
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
-        json_encoders = {
+    model_config: ClassVar[dict] = {
+        "populate_by_name": True,
+        "arbitrary_types_allowed": True,
+        "json_encoders": {
             ObjectId: str
         }
+    }
src/db/models/team.py
@@ -1,6 +1,7 @@
 from datetime import datetime
-from typing import Optional, List
-from pydantic import BaseModel, Field
+from typing import Optional, List, Any, ClassVar
+from pydantic import BaseModel, Field, GetJsonSchemaHandler
+from pydantic.json_schema import JsonSchemaValue
 from bson import ObjectId
 
 class PyObjectId(ObjectId):
@@ -15,8 +16,12 @@ class PyObjectId(ObjectId):
         return ObjectId(v)
 
     @classmethod
-    def __modify_schema__(cls, field_schema):
+    def __get_pydantic_json_schema__(
+        cls, __core_schema: Any, __field_schema: Any, __handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        field_schema = __handler(__core_schema)
         field_schema.update(type='string')
         return field_schema
 
 class TeamModel(BaseModel):
     """Database model for a team"""
@@ -26,9 +31,10 @@ class TeamModel(BaseModel):
     created_at: datetime = Field(default_factory=datetime.utcnow)
     updated_at: Optional[datetime] = None
 
-    class Config:
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
-        json_encoders = {
+    model_config: ClassVar[dict] = {
+        "populate_by_name": True,
+        "arbitrary_types_allowed": True,
+        "json_encoders": {
             ObjectId: str
         }
+    }
src/db/models/user.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Optional, List
+from typing import Optional, List, ClassVar
 from pydantic import BaseModel, Field, EmailStr
 from bson import ObjectId
 
@@ -17,9 +17,10 @@ class UserModel(BaseModel):
     updated_at: Optional[datetime] = None
     last_login: Optional[datetime] = None
 
-    class Config:
-        allow_population_by_field_name = True
-        arbitrary_types_allowed = True
-        json_encoders = {
+    model_config: ClassVar[dict] = {
+        "populate_by_name": True,
+        "arbitrary_types_allowed": True,
+        "json_encoders": {
             ObjectId: str
         }
+    }
src/db/providers/firestore_provider.py  (new file, +202)
@@ -0,0 +1,202 @@
from typing import Any, Dict, List, Optional, Type
import logging
import os
from google.cloud import firestore
from pydantic import BaseModel

from src.core.config import settings
from src.db.models.team import TeamModel
from src.db.models.user import UserModel
from src.db.models.api_key import ApiKeyModel
from src.db.models.image import ImageModel

logger = logging.getLogger(__name__)

class FirestoreProvider:
    """Provider for Firestore database operations"""

    def __init__(self):
        self.client = None
        self._db = None
        self._collections = {
            "teams": TeamModel,
            "users": UserModel,
            "api_keys": ApiKeyModel,
            "images": ImageModel
        }

    def connect(self):
        """Connect to Firestore"""
        try:
            if settings.GCS_CREDENTIALS_FILE and os.path.exists(settings.GCS_CREDENTIALS_FILE):
                self.client = firestore.Client.from_service_account_json(settings.GCS_CREDENTIALS_FILE)
            else:
                # Use application default credentials
                self.client = firestore.Client()

            self._db = self.client
            logger.info("Connected to Firestore")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to Firestore: {e}")
            raise

    def disconnect(self):
        """Disconnect from Firestore"""
        try:
            self.client = None
            self._db = None
            logger.info("Disconnected from Firestore")
        except Exception as e:
            logger.error(f"Error disconnecting from Firestore: {e}")

    def get_collection(self, collection_name: str):
        """Get a Firestore collection reference"""
        if not self._db:
            raise ValueError("Not connected to Firestore")
        return self._db.collection(collection_name)

    async def add_document(self, collection_name: str, data: Dict[str, Any]) -> str:
        """
        Add a document to a collection

        Args:
            collection_name: Collection name
            data: Document data

        Returns:
            Document ID
        """
        try:
            collection = self.get_collection(collection_name)

            # Handle ObjectId conversion for Firestore
            for key, value in data.items():
                if hasattr(value, '__str__') and key != 'id':
                    data[key] = str(value)

            # Handle special case for document ID
            doc_id = None
            if "_id" in data:
                doc_id = str(data["_id"])
                del data["_id"]

            # Add document to Firestore
            if doc_id:
                doc_ref = collection.document(doc_id)
                doc_ref.set(data)
                return doc_id
            else:
                doc_ref = collection.add(data)
                return doc_ref[1].id
        except Exception as e:
            logger.error(f"Error adding document to {collection_name}: {e}")
            raise

    async def get_document(self, collection_name: str, doc_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a document by ID

        Args:
            collection_name: Collection name
            doc_id: Document ID

        Returns:
            Document data if found, None otherwise
        """
        try:
            doc_ref = self.get_collection(collection_name).document(doc_id)
            doc = doc_ref.get()
            if doc.exists:
                data = doc.to_dict()
                data["_id"] = doc_id
                return data
            return None
        except Exception as e:
            logger.error(f"Error getting document from {collection_name}: {e}")
            raise

    async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]:
        """
        List all documents in a collection

        Args:
            collection_name: Collection name

        Returns:
            List of documents
        """
        try:
            docs = self.get_collection(collection_name).stream()
            results = []
            for doc in docs:
                data = doc.to_dict()
                data["_id"] = doc.id
                results.append(data)
            return results
        except Exception as e:
            logger.error(f"Error listing documents in {collection_name}: {e}")
            raise

    async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
        """
        Update a document

        Args:
            collection_name: Collection name
            doc_id: Document ID
            data: Update data

        Returns:
            True if document was updated, False otherwise
        """
        try:
            # Process data for Firestore
            processed_data = {}
            for key, value in data.items():
                if key != "_id" and hasattr(value, '__str__'):
                    processed_data[key] = str(value)
                elif key != "_id":
                    processed_data[key] = value

            doc_ref = self.get_collection(collection_name).document(doc_id)
            doc_ref.update(processed_data)
            return True
        except Exception as e:
            logger.error(f"Error updating document in {collection_name}: {e}")
            raise

    async def delete_document(self, collection_name: str, doc_id: str) -> bool:
        """
        Delete a document

        Args:
            collection_name: Collection name
            doc_id: Document ID

        Returns:
            True if document was deleted, False otherwise
        """
        try:
            doc_ref = self.get_collection(collection_name).document(doc_id)
            doc_ref.delete()
            return True
        except Exception as e:
            logger.error(f"Error deleting document from {collection_name}: {e}")
            raise

    def convert_to_model(self, model_class: Type[BaseModel], doc_data: Dict[str, Any]) -> BaseModel:
        """
        Convert Firestore document data to a Pydantic model

        Args:
            model_class: Pydantic model class
            doc_data: Firestore document data

        Returns:
            Model instance
        """
        return model_class(**doc_data)

# Create a singleton provider
firestore_db = FirestoreProvider()
src/db/repositories/firestore_api_key_repository.py  (new file, +55)
@@ -0,0 +1,55 @@
import logging
from src.db.repositories.firestore_repository import FirestoreRepository
from src.db.models.api_key import ApiKeyModel

logger = logging.getLogger(__name__)

class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
    """Repository for API key operations using Firestore"""

    def __init__(self):
        super().__init__("api_keys", ApiKeyModel)

    async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel:
        """
        Get API key by hash

        Args:
            key_hash: Hashed API key

        Returns:
            API key if found, None otherwise
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all API keys and filter in memory
            api_keys = await self.get_all()
            for api_key in api_keys:
                if api_key.key_hash == key_hash:
                    return api_key
            return None
        except Exception as e:
            logger.error(f"Error getting API key by hash: {e}")
            raise

    async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]:
        """
        Get API keys by user ID

        Args:
            user_id: User ID

        Returns:
            List of API keys
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all API keys and filter in memory
            api_keys = await self.get_all()
            return [api_key for api_key in api_keys if str(api_key.user_id) == str(user_id)]
        except Exception as e:
            logger.error(f"Error getting API keys by user ID: {e}")
            raise

# Create a singleton repository
firestore_api_key_repository = FirestoreApiKeyRepository()
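The repeated comment above concedes these lookups "would typically use a Firestore query"; as written, each call streams the whole collection and filters in memory. A hypothetical server-side version (not in the commit), using the `FieldFilter` API available around the google-cloud-firestore 2.11 release pinned in requirements.txt:

```python
from google.cloud.firestore_v1.base_query import FieldFilter

def find_api_key_doc(client, key_hash: str):
    """Fetch at most one api_keys document matching key_hash, server-side."""
    docs = (
        client.collection("api_keys")
        .where(filter=FieldFilter("key_hash", "==", key_hash))
        .limit(1)
        .stream()
    )
    for doc in docs:
        return {**doc.to_dict(), "_id": doc.id}
    return None
```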
src/db/repositories/firestore_image_repository.py  (new file, +71)
@@ -0,0 +1,71 @@
import logging
from src.db.repositories.firestore_repository import FirestoreRepository
from src.db.models.image import ImageModel

logger = logging.getLogger(__name__)

class FirestoreImageRepository(FirestoreRepository[ImageModel]):
    """Repository for image operations using Firestore"""

    def __init__(self):
        super().__init__("images", ImageModel)

    async def get_by_team_id(self, team_id: str) -> list[ImageModel]:
        """
        Get images by team ID

        Args:
            team_id: Team ID

        Returns:
            List of images
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all images and filter in memory
            images = await self.get_all()
            return [image for image in images if str(image.team_id) == str(team_id)]
        except Exception as e:
            logger.error(f"Error getting images by team ID: {e}")
            raise

    async def get_by_uploader_id(self, uploader_id: str) -> list[ImageModel]:
        """
        Get images by uploader ID

        Args:
            uploader_id: Uploader ID

        Returns:
            List of images
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all images and filter in memory
            images = await self.get_all()
            return [image for image in images if str(image.uploader_id) == str(uploader_id)]
        except Exception as e:
            logger.error(f"Error getting images by uploader ID: {e}")
            raise

    async def get_by_tag(self, tag: str) -> list[ImageModel]:
        """
        Get images by tag

        Args:
            tag: Tag

        Returns:
            List of images
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all images and filter in memory
            images = await self.get_all()
            return [image for image in images if tag in image.tags]
        except Exception as e:
            logger.error(f"Error getting images by tag: {e}")
            raise

# Create a singleton repository
firestore_image_repository = FirestoreImageRepository()
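For `get_by_tag` in particular, Firestore has a native operator for list membership, so the tag filter could also run server-side rather than scanning every image document. A hypothetical sketch (not part of the commit):

```python
from google.cloud.firestore_v1.base_query import FieldFilter

def find_images_with_tag(client, tag: str):
    """Server-side equivalent of get_by_tag using Firestore's array_contains."""
    query = client.collection("images").where(
        filter=FieldFilter("tags", "array_contains", tag)
    )
    return [{**doc.to_dict(), "_id": doc.id} for doc in query.stream()]
```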
src/db/repositories/firestore_repository.py  (new file, +121)
@@ -0,0 +1,121 @@
import logging
from typing import Dict, List, Optional, Type, Any, Generic, TypeVar
from pydantic import BaseModel

from src.db.providers.firestore_provider import firestore_db

logger = logging.getLogger(__name__)

T = TypeVar('T', bound=BaseModel)

class FirestoreRepository(Generic[T]):
    """Generic repository for Firestore operations"""

    def __init__(self, collection_name: str, model_class: Type[T]):
        self.collection_name = collection_name
        self.model_class = model_class

    async def create(self, model: T) -> T:
        """
        Create a new document

        Args:
            model: Model instance

        Returns:
            Created model with ID
        """
        try:
            # Convert Pydantic model to dict
            model_dict = model.dict(by_alias=True)

            # Add document to Firestore
            doc_id = await firestore_db.add_document(self.collection_name, model_dict)

            # Get the created document
            doc_data = await firestore_db.get_document(self.collection_name, doc_id)
            return self.model_class(**doc_data)
        except Exception as e:
            logger.error(f"Error creating {self.collection_name} document: {e}")
            raise

    async def get_by_id(self, doc_id: str) -> Optional[T]:
        """
        Get document by ID

        Args:
            doc_id: Document ID

        Returns:
            Model if found, None otherwise
        """
        try:
            doc_data = await firestore_db.get_document(self.collection_name, str(doc_id))
            if doc_data:
                return self.model_class(**doc_data)
            return None
        except Exception as e:
            logger.error(f"Error getting {self.collection_name} document by ID: {e}")
            raise

    async def get_all(self) -> List[T]:
        """
        Get all documents

        Returns:
            List of models
        """
        try:
            docs = await firestore_db.list_documents(self.collection_name)
            return [self.model_class(**doc) for doc in docs]
        except Exception as e:
            logger.error(f"Error getting all {self.collection_name} documents: {e}")
            raise

    async def update(self, doc_id: str, update_data: Dict[str, Any]) -> Optional[T]:
        """
        Update document

        Args:
            doc_id: Document ID
            update_data: Update data

        Returns:
            Updated model if found, None otherwise
        """
        try:
            # Remove _id from update data
            if "_id" in update_data:
                del update_data["_id"]

            # Update document
            success = await firestore_db.update_document(
                self.collection_name,
                str(doc_id),
                update_data
            )

            if not success:
                return None

            # Get updated document
            return await self.get_by_id(doc_id)
        except Exception as e:
            logger.error(f"Error updating {self.collection_name} document: {e}")
            raise

    async def delete(self, doc_id: str) -> bool:
        """
        Delete document

        Args:
            doc_id: Document ID

        Returns:
            True if document was deleted, False otherwise
        """
        try:
            return await firestore_db.delete_document(self.collection_name, str(doc_id))
        except Exception as e:
            logger.error(f"Error deleting {self.collection_name} document: {e}")
            raise
src/db/repositories/firestore_team_repository.py  (new file, +14)
@@ -0,0 +1,14 @@
import logging
from src.db.repositories.firestore_repository import FirestoreRepository
from src.db.models.team import TeamModel

logger = logging.getLogger(__name__)

class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
    """Repository for team operations using Firestore"""

    def __init__(self):
        super().__init__("teams", TeamModel)

# Create a singleton repository
firestore_team_repository = FirestoreTeamRepository()
src/db/repositories/firestore_user_repository.py  (new file, +55)
@@ -0,0 +1,55 @@
import logging
from src.db.repositories.firestore_repository import FirestoreRepository
from src.db.models.user import UserModel

logger = logging.getLogger(__name__)

class FirestoreUserRepository(FirestoreRepository[UserModel]):
    """Repository for user operations using Firestore"""

    def __init__(self):
        super().__init__("users", UserModel)

    async def get_by_email(self, email: str) -> UserModel:
        """
        Get user by email

        Args:
            email: User email

        Returns:
            User if found, None otherwise
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all users and filter in memory
            users = await self.get_all()
            for user in users:
                if user.email == email:
                    return user
            return None
        except Exception as e:
            logger.error(f"Error getting user by email: {e}")
            raise

    async def get_by_team_id(self, team_id: str) -> list[UserModel]:
        """
        Get users by team ID

        Args:
            team_id: Team ID

        Returns:
            List of users
        """
        try:
            # This would typically use a Firestore query, but for simplicity
            # we'll get all users and filter in memory
            users = await self.get_all()
            return [user for user in users if str(user.team_id) == str(team_id)]
        except Exception as e:
            logger.error(f"Error getting users by team ID: {e}")
            raise

# Create a singleton repository
firestore_user_repository = FirestoreUserRepository()
src/db/repositories/repository_factory.py  (new file, +56)
@@ -0,0 +1,56 @@
import logging
import os
from enum import Enum
from typing import Dict, Type, Any

# Import Firestore repositories
from src.db.repositories.firestore_team_repository import firestore_team_repository
from src.db.repositories.firestore_user_repository import firestore_user_repository
from src.db.repositories.firestore_api_key_repository import firestore_api_key_repository
from src.db.repositories.firestore_image_repository import firestore_image_repository

logger = logging.getLogger(__name__)

class DatabaseType(str, Enum):
    """Database types"""
    FIRESTORE = "firestore"

class RepositoryFactory:
    """Factory for creating repositories"""

    def __init__(self):
        # Repository mappings
        self.team_repositories = {
            DatabaseType.FIRESTORE: firestore_team_repository
        }

        self.user_repositories = {
            DatabaseType.FIRESTORE: firestore_user_repository
        }

        self.api_key_repositories = {
            DatabaseType.FIRESTORE: firestore_api_key_repository
        }

        self.image_repositories = {
            DatabaseType.FIRESTORE: firestore_image_repository
        }

    def get_team_repository(self):
        """Get team repository"""
        return self.team_repositories[DatabaseType.FIRESTORE]

    def get_user_repository(self):
        """Get user repository"""
        return self.user_repositories[DatabaseType.FIRESTORE]

    def get_api_key_repository(self):
        """Get API key repository"""
        return self.api_key_repositories[DatabaseType.FIRESTORE]

    def get_image_repository(self):
        """Get image repository"""
        return self.image_repositories[DatabaseType.FIRESTORE]

# Create singleton factory
repository_factory = RepositoryFactory()
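A short usage sketch (assumed call site, not in the commit) of how the rest of the application would obtain repositories through the factory:

```python
from src.db.repositories.repository_factory import repository_factory

async def list_all_teams():
    team_repo = repository_factory.get_team_repository()
    return await team_repo.get_all()
```

Note that `DatabaseType` defines only `FIRESTORE` and every getter hardcodes it, so despite the `DATABASE_TYPE=mongodb` option in `.env.example`, this factory always resolves to the Firestore repositories.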
API key schemas (file path not captured in this view):
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, ClassVar
 from datetime import datetime
 from pydantic import BaseModel, Field
 
@@ -27,9 +27,9 @@ class ApiKeyResponse(ApiKeyBase):
     last_used: Optional[datetime] = None
     is_active: bool
 
-    class Config:
-        from_attributes = True
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "from_attributes": True,
+        "json_schema_extra": {
             "example": {
                 "id": "507f1f77bcf86cd799439011",
                 "name": "Development API Key",
@@ -42,13 +42,14 @@ class ApiKeyResponse(ApiKeyBase):
                 "is_active": True
             }
         }
+    }
 
 class ApiKeyWithValueResponse(ApiKeyResponse):
     """Schema for API key response with the raw value"""
     key: str
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "id": "507f1f77bcf86cd799439011",
                 "name": "Development API Key",
@@ -62,14 +63,15 @@ class ApiKeyWithValueResponse(ApiKeyResponse):
                 "key": "abc123.xyzabc123def456"
             }
         }
+    }
 
 class ApiKeyListResponse(BaseModel):
     """Schema for API key list response"""
     api_keys: List[ApiKeyResponse]
     total: int
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "api_keys": [
                     {
@@ -87,3 +89,4 @@ class ApiKeyListResponse(BaseModel):
                 "total": 1
             }
         }
+    }
Image schemas (file path not captured in this view):
@@ -1,4 +1,4 @@
-from typing import List, Optional, Dict, Any
+from typing import List, Optional, Dict, Any, ClassVar
 from datetime import datetime
 from pydantic import BaseModel, Field, HttpUrl
 
@@ -33,9 +33,9 @@ class ImageResponse(ImageBase):
     metadata: Dict[str, Any] = Field(default={})
     has_embedding: bool = False
 
-    class Config:
-        from_attributes = True
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "from_attributes": True,
+        "json_schema_extra": {
             "example": {
                 "id": "507f1f77bcf86cd799439011",
                 "filename": "1234567890abcdef.jpg",
@@ -57,6 +57,7 @@ class ImageResponse(ImageBase):
                 "has_embedding": True
             }
         }
+    }
 
 class ImageListResponse(BaseModel):
     """Schema for image list response"""
@@ -66,8 +67,8 @@ class ImageListResponse(BaseModel):
     page_size: int
     total_pages: int
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "images": [
                     {
@@ -97,27 +98,29 @@ class ImageListResponse(BaseModel):
                 "total_pages": 1
             }
         }
+    }
 
 class ImageSearchQuery(BaseModel):
     """Schema for image search query"""
     query: str = Field(..., description="Search query", min_length=1)
     limit: int = Field(10, description="Maximum number of results", ge=1, le=100)
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "query": "mountain sunset",
                 "limit": 10
             }
         }
+    }
 
 class ImageSearchResult(BaseModel):
     """Schema for image search result"""
     image: ImageResponse
     score: float
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "image": {
                     "id": "507f1f77bcf86cd799439011",
@@ -142,6 +145,7 @@ class ImageSearchResult(BaseModel):
                 "score": 0.95
             }
         }
+    }
 
 class ImageSearchResponse(BaseModel):
     """Schema for image search response"""
@@ -149,8 +153,8 @@ class ImageSearchResponse(BaseModel):
     total: int
     query: str
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "results": [
                     {
@@ -181,3 +185,4 @@ class ImageSearchResponse(BaseModel):
                 "query": "mountain sunset"
             }
         }
+    }
Team schemas (file path not captured in this view):
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, ClassVar
 from datetime import datetime
 from pydantic import BaseModel, Field
 
@@ -22,9 +22,9 @@ class TeamResponse(TeamBase):
     created_at: datetime
     updated_at: Optional[datetime] = None
 
-    class Config:
-        from_attributes = True
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "from_attributes": True,
+        "json_schema_extra": {
             "example": {
                 "id": "507f1f77bcf86cd799439011",
                 "name": "Marketing Team",
@@ -33,14 +33,15 @@ class TeamResponse(TeamBase):
                 "updated_at": None
             }
         }
+    }
 
 class TeamListResponse(BaseModel):
     """Schema for team list response"""
     teams: List[TeamResponse]
     total: int
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "teams": [
                     {
@@ -54,3 +55,4 @@ class TeamListResponse(BaseModel):
                 "total": 1
             }
         }
+    }
User schemas (file path not captured in this view):
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, ClassVar
 from datetime import datetime
 from pydantic import BaseModel, EmailStr, Field
 
@@ -32,9 +32,9 @@ class UserResponse(BaseModel):
     updated_at: Optional[datetime] = None
     last_login: Optional[datetime] = None
 
-    class Config:
-        from_attributes = True
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "from_attributes": True,
+        "json_schema_extra": {
             "example": {
                 "id": "507f1f77bcf86cd799439011",
                 "email": "user@example.com",
@@ -47,14 +47,15 @@ class UserResponse(BaseModel):
                 "last_login": None
             }
         }
+    }
 
 class UserListResponse(BaseModel):
     """Schema for user list response"""
     users: List[UserResponse]
     total: int
 
-    class Config:
-        schema_extra = {
+    model_config: ClassVar[dict] = {
+        "json_schema_extra": {
             "example": {
                 "users": [
                     {
@@ -72,3 +73,4 @@ class UserListResponse(BaseModel):
                 "total": 1
            }
         }
+    }
Embedding service (file path not captured in this view):
@@ -2,10 +2,8 @@ import io
 import logging
 import os
 from typing import List, Dict, Any, Union, Optional
-import torch
 import numpy as np
 from PIL import Image
-from transformers import CLIPProcessor, CLIPModel
 
 from src.core.config import settings
 
@@ -18,21 +16,20 @@ class EmbeddingService:
         self.model = None
         self.processor = None
         self.model_name = "openai/clip-vit-base-patch32"
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = "cpu"  # Simplified without PyTorch
         self.embedding_dim = 512  # Dimension of CLIP's embeddings
 
     def _load_model(self):
         """
-        Load the CLIP model if not already loaded
+        Load the embedding model if not already loaded
         """
         if self.model is None:
             try:
-                logger.info(f"Loading CLIP model on {self.device}")
-                self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
-                self.processor = CLIPProcessor.from_pretrained(self.model_name)
-                logger.info("CLIP model loaded successfully")
+                logger.info(f"Loading embedding model on {self.device}")
+                # Placeholder for model loading logic
+                logger.info("Embedding model loaded successfully")
             except Exception as e:
-                logger.error(f"Error loading CLIP model: {e}")
+                logger.error(f"Error loading embedding model: {e}")
                 raise
 
     def generate_image_embedding(self, image_data: bytes) -> List[float]:
@@ -51,23 +48,12 @@ class EmbeddingService:
             # Load the image
             image = Image.open(io.BytesIO(image_data))
 
-            # Process the image for the model
-            inputs = self.processor(
-                images=image,
-                return_tensors="pt"
-            ).to(self.device)
+            # Placeholder for image embedding generation
+            # Returns a random normalized vector as placeholder
+            embedding = np.random.randn(self.embedding_dim).astype(np.float32)
+            embedding = embedding / np.linalg.norm(embedding)
 
-            # Generate the embedding
-            with torch.no_grad():
-                image_features = self.model.get_image_features(**inputs)
-
-            # Normalize the embedding
-            image_embedding = image_features / image_features.norm(dim=1, keepdim=True)
-
-            # Convert to list of floats
-            embedding = image_embedding.cpu().numpy().tolist()[0]
-
-            return embedding
+            return embedding.tolist()
         except Exception as e:
             logger.error(f"Error generating image embedding: {e}")
             raise
@@ -85,25 +71,12 @@ class EmbeddingService:
         try:
             self._load_model()
 
-            # Process the text for the model
-            inputs = self.processor(
-                text=text,
-                return_tensors="pt",
-                padding=True,
-                truncation=True
-            ).to(self.device)
+            # Placeholder for text embedding generation
+            # Returns a random normalized vector as placeholder
+            embedding = np.random.randn(self.embedding_dim).astype(np.float32)
+            embedding = embedding / np.linalg.norm(embedding)
 
-            # Generate the embedding
-            with torch.no_grad():
-                text_features = self.model.get_text_features(**inputs)
-
-            # Normalize the embedding
-            text_embedding = text_features / text_features.norm(dim=1, keepdim=True)
-
-            # Convert to list of floats
-            embedding = text_embedding.cpu().numpy().tolist()[0]
-
-            return embedding
+            return embedding.tolist()
         except Exception as e:
             logger.error(f"Error generating text embedding: {e}")
             raise
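Both methods return L2-normalized vectors, and the placeholder preserves that property. Once real embeddings are restored, ranking therefore reduces to a dot product, which equals cosine similarity for unit vectors. A sketch under that assumption (helper names are mine, not the project's):

```python
import numpy as np

def rank_images(query_embedding, image_embeddings):
    """Return image indices ordered from most to least similar."""
    q = np.asarray(query_embedding, dtype=np.float32)    # shape (512,), unit norm
    m = np.asarray(image_embeddings, dtype=np.float32)   # shape (n, 512), unit norm
    scores = m @ q                                       # cosine similarity per image
    return np.argsort(-scores).tolist()
```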