cp
This commit is contained in:
parent
d6dea7ef17
commit
ae219c8111
34
.env.example
Normal file
34
.env.example
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# Application Settings
|
||||||
|
PROJECT_NAME="Image Management API"
|
||||||
|
API_V1_STR="/api/v1"
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# CORS Settings
|
||||||
|
CORS_ORIGINS=*
|
||||||
|
|
||||||
|
# Database Settings
|
||||||
|
# Choose database type: "mongodb" or "firestore"
|
||||||
|
DATABASE_TYPE=mongodb
|
||||||
|
|
||||||
|
# MongoDB Settings (used when DATABASE_TYPE=mongodb)
|
||||||
|
DATABASE_URI=mongodb://localhost:27017
|
||||||
|
DATABASE_NAME=imagedb
|
||||||
|
|
||||||
|
# Google Cloud Firestore Settings (used when DATABASE_TYPE=firestore)
|
||||||
|
# Path to service account credentials file (optional, uses application default credentials if not set)
|
||||||
|
GCS_CREDENTIALS_FILE=path/to/credentials.json
|
||||||
|
|
||||||
|
# Google Cloud Storage Settings
|
||||||
|
GCS_BUCKET_NAME=your-bucket-name
|
||||||
|
|
||||||
|
# Security Settings
|
||||||
|
API_KEY_SECRET=super-secret-key-for-development-only
|
||||||
|
API_KEY_EXPIRY_DAYS=365
|
||||||
|
|
||||||
|
# Vector Database Settings (for image embeddings)
|
||||||
|
VECTOR_DB_API_KEY=your-pinecone-api-key
|
||||||
|
VECTOR_DB_ENVIRONMENT=your-pinecone-environment
|
||||||
|
VECTOR_DB_INDEX_NAME=image-embeddings
|
||||||
|
|
||||||
|
# Rate Limiting
|
||||||
|
RATE_LIMIT_PER_MINUTE=100
|
||||||
341
README.md
341
README.md
@ -1,156 +1,187 @@
|
|||||||
# SEREACT - Secure Image Management API
|
# SEREACT - Secure Image Management API
|
||||||
|
|
||||||
SEREACT is a secure API for storing, organizing, and retrieving images with advanced search capabilities.
|
SEREACT is a secure API for storing, organizing, and retrieving images with advanced search capabilities.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Secure image storage in Google Cloud Storage
|
- Secure image storage in Google Cloud Storage
|
||||||
- Team-based organization and access control
|
- Team-based organization and access control
|
||||||
- API key authentication
|
- API key authentication
|
||||||
- Semantic search using image embeddings
|
- Semantic search using image embeddings
|
||||||
- Metadata extraction and storage
|
- Metadata extraction and storage
|
||||||
- Image processing capabilities
|
- Image processing capabilities
|
||||||
- Multi-team support
|
- Multi-team support
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
sereact/
|
sereact/
|
||||||
├── images/ # Sample images for testing
|
├── images/ # Sample images for testing
|
||||||
├── sereact/ # Main application code
|
├── sereact/ # Main application code
|
||||||
│ ├── deployment/ # Deployment configurations
|
│ ├── deployment/ # Deployment configurations
|
||||||
│ │ └── cloud-run/ # Google Cloud Run configuration
|
│ │ └── cloud-run/ # Google Cloud Run configuration
|
||||||
│ ├── docs/ # Documentation
|
│ ├── docs/ # Documentation
|
||||||
│ │ └── api/ # API documentation
|
│ │ └── api/ # API documentation
|
||||||
│ ├── scripts/ # Utility scripts
|
│ ├── scripts/ # Utility scripts
|
||||||
│ ├── src/ # Source code
|
│ ├── src/ # Source code
|
||||||
│ │ ├── api/ # API endpoints
|
│ │ ├── api/ # API endpoints
|
||||||
│ │ │ └── v1/ # API version 1
|
│ │ │ └── v1/ # API version 1
|
||||||
│ │ ├── core/ # Core modules
|
│ │ ├── core/ # Core modules
|
||||||
│ │ ├── db/ # Database models and repositories
|
│ │ ├── db/ # Database models and repositories
|
||||||
│ │ │ ├── models/ # Data models
|
│ │ │ ├── models/ # Data models
|
||||||
│ │ │ └── repositories/ # Database operations
|
│ │ │ └── repositories/ # Database operations
|
||||||
│ │ ├── schemas/ # API schemas (request/response)
|
│ │ ├── schemas/ # API schemas (request/response)
|
||||||
│ │ └── services/ # Business logic services
|
│ │ └── services/ # Business logic services
|
||||||
│ └── tests/ # Test code
|
│ └── tests/ # Test code
|
||||||
│ ├── api/ # API tests
|
│ ├── api/ # API tests
|
||||||
│ ├── db/ # Database tests
|
│ ├── db/ # Database tests
|
||||||
│ └── services/ # Service tests
|
│ └── services/ # Service tests
|
||||||
├── main.py # Application entry point
|
├── main.py # Application entry point
|
||||||
├── requirements.txt # Python dependencies
|
├── requirements.txt # Python dependencies
|
||||||
└── README.md # This file
|
└── README.md # This file
|
||||||
```
|
```
|
||||||
|
|
||||||
## Technology Stack
|
## Technology Stack
|
||||||
|
|
||||||
- FastAPI - Web framework
|
- FastAPI - Web framework
|
||||||
- MongoDB - Database
|
- Firestore - Database
|
||||||
- Google Cloud Storage - Image storage
|
- Google Cloud Storage - Image storage
|
||||||
- Pinecone - Vector database for semantic search
|
- Pinecone - Vector database for semantic search
|
||||||
- CLIP - Image embedding model
|
- CLIP - Image embedding model
|
||||||
- PyTorch - Deep learning framework
|
- NumPy - Scientific computing
|
||||||
- Pydantic - Data validation
|
- Pydantic - Data validation
|
||||||
|
|
||||||
## Setup and Installation
|
## Setup and Installation
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
- Python 3.8+
|
- Python 3.8+
|
||||||
- MongoDB
|
- Google Cloud account with Firestore and Storage enabled
|
||||||
- Google Cloud account with Storage enabled
|
- (Optional) Pinecone account for semantic search
|
||||||
- (Optional) Pinecone account for semantic search
|
|
||||||
|
### Installation
|
||||||
### Installation
|
|
||||||
|
1. Clone the repository:
|
||||||
1. Clone the repository:
|
```bash
|
||||||
```bash
|
git clone https://github.com/yourusername/sereact.git
|
||||||
git clone https://github.com/yourusername/sereact.git
|
cd sereact
|
||||||
cd sereact
|
```
|
||||||
```
|
|
||||||
|
2. Create and activate a virtual environment:
|
||||||
2. Create and activate a virtual environment:
|
```bash
|
||||||
```bash
|
python -m venv venv
|
||||||
python -m venv venv
|
source venv/bin/activate # Linux/macOS
|
||||||
source venv/bin/activate # Linux/macOS
|
venv\Scripts\activate # Windows
|
||||||
venv\Scripts\activate # Windows
|
```
|
||||||
```
|
|
||||||
|
3. Install dependencies:
|
||||||
3. Install dependencies:
|
```bash
|
||||||
```bash
|
pip install -r requirements.txt
|
||||||
pip install -r requirements.txt
|
```
|
||||||
```
|
|
||||||
|
4. Create a `.env` file with the following environment variables:
|
||||||
4. Create a `.env` file with the following environment variables:
|
```
|
||||||
```
|
# Firestore
|
||||||
# MongoDB
|
DATABASE_NAME=imagedb
|
||||||
DATABASE_URI=mongodb://localhost:27017
|
FIRESTORE_PROJECT_ID=your-gcp-project-id
|
||||||
DATABASE_NAME=imagedb
|
FIRESTORE_CREDENTIALS_FILE=path/to/firestore-credentials.json
|
||||||
|
|
||||||
# Google Cloud Storage
|
# Google Cloud Storage
|
||||||
GCS_BUCKET_NAME=your-bucket-name
|
GCS_BUCKET_NAME=your-bucket-name
|
||||||
GCS_CREDENTIALS_FILE=path/to/credentials.json
|
GCS_CREDENTIALS_FILE=path/to/credentials.json
|
||||||
|
|
||||||
# Security
|
# Security
|
||||||
API_KEY_SECRET=your-secret-key
|
API_KEY_SECRET=your-secret-key
|
||||||
|
|
||||||
# Vector database (optional)
|
# Vector database (optional)
|
||||||
VECTOR_DB_API_KEY=your-pinecone-api-key
|
VECTOR_DB_API_KEY=your-pinecone-api-key
|
||||||
VECTOR_DB_ENVIRONMENT=your-pinecone-environment
|
VECTOR_DB_ENVIRONMENT=your-pinecone-environment
|
||||||
VECTOR_DB_INDEX_NAME=image-embeddings
|
VECTOR_DB_INDEX_NAME=image-embeddings
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Run the application:
|
5. Run the application:
|
||||||
```bash
|
```bash
|
||||||
uvicorn main:app --reload
|
uvicorn main:app --reload
|
||||||
```
|
```
|
||||||
|
|
||||||
6. Visit `http://localhost:8000/docs` in your browser to access the API documentation.
|
6. Visit `http://localhost:8000/docs` in your browser to access the API documentation.
|
||||||
|
|
||||||
## API Endpoints
|
## API Endpoints
|
||||||
|
|
||||||
The API provides the following main endpoints:
|
The API provides the following main endpoints:
|
||||||
|
|
||||||
- `/api/v1/auth/*` - Authentication and API key management
|
- `/api/v1/auth/*` - Authentication and API key management
|
||||||
- `/api/v1/teams/*` - Team management
|
- `/api/v1/teams/*` - Team management
|
||||||
- `/api/v1/users/*` - User management
|
- `/api/v1/users/*` - User management
|
||||||
- `/api/v1/images/*` - Image upload, download, and management
|
- `/api/v1/images/*` - Image upload, download, and management
|
||||||
- `/api/v1/search/*` - Image search functionality
|
- `/api/v1/search/*` - Image search functionality
|
||||||
|
|
||||||
Refer to the Swagger UI documentation at `/docs` for detailed endpoint information.
|
Refer to the Swagger UI documentation at `/docs` for detailed endpoint information.
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
### Running Tests
|
### Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pytest
|
pytest
|
||||||
```
|
```
|
||||||
|
|
||||||
### Creating a New API Version
|
### Creating a New API Version
|
||||||
|
|
||||||
1. Create a new package under `src/api/` (e.g., `v2`)
|
1. Create a new package under `src/api/` (e.g., `v2`)
|
||||||
2. Implement new endpoints
|
2. Implement new endpoints
|
||||||
3. Update the main.py file to include the new routers
|
3. Update the main.py file to include the new routers
|
||||||
|
|
||||||
## Deployment
|
## Deployment
|
||||||
|
|
||||||
### Google Cloud Run
|
### Google Cloud Run
|
||||||
|
|
||||||
1. Build the Docker image:
|
1. Build the Docker image:
|
||||||
```bash
|
```bash
|
||||||
docker build -t gcr.io/your-project/sereact .
|
docker build -t gcr.io/your-project/sereact .
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Push to Google Container Registry:
|
2. Push to Google Container Registry:
|
||||||
```bash
|
```bash
|
||||||
docker push gcr.io/your-project/sereact
|
docker push gcr.io/your-project/sereact
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Deploy to Cloud Run:
|
3. Deploy to Cloud Run:
|
||||||
```bash
|
```bash
|
||||||
gcloud run deploy sereact --image gcr.io/your-project/sereact --platform managed
|
gcloud run deploy sereact --image gcr.io/your-project/sereact --platform managed
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## Local Development with Docker Compose
|
||||||
|
|
||||||
|
To run the application locally using Docker Compose:
|
||||||
|
|
||||||
|
1. Make sure you have Docker and Docker Compose installed
|
||||||
|
2. Run the following command in the project root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Build the API container based on the Dockerfile
|
||||||
|
- Mount your local codebase into the container for live reloading
|
||||||
|
- Mount your Firestore credentials for authentication
|
||||||
|
- Expose the API on http://localhost:8000
|
||||||
|
|
||||||
|
To stop the containers:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose down
|
||||||
|
```
|
||||||
|
|
||||||
|
To rebuild containers after making changes to the Dockerfile or requirements:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
## Additional Information
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
This project is licensed under the MIT License - see the LICENSE file for details.
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
||||||
17
docker-compose.yml
Normal file
17
docker-compose.yml
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
api:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
volumes:
|
||||||
|
- .:/app
|
||||||
|
- ${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
|
||||||
|
environment:
|
||||||
|
- PYTHONUNBUFFERED=1
|
||||||
|
- ENVIRONMENT=development
|
||||||
|
- DATABASE_NAME=imagedb
|
||||||
|
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
|
||||||
|
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
|
||||||
|
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
@ -1,21 +1,18 @@
|
|||||||
fastapi==0.104.1
|
fastapi==0.104.1
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
pydantic==2.4.2
|
pydantic==2.4.2
|
||||||
pydantic-settings==2.0.3
|
pydantic-settings==2.0.3
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
google-cloud-storage==2.12.0
|
google-cloud-storage==2.12.0
|
||||||
google-cloud-vision==3.4.5
|
google-cloud-vision==3.4.5
|
||||||
pymongo==4.5.0
|
google-cloud-firestore==2.11.1
|
||||||
motor==3.3.1
|
python-multipart==0.0.6
|
||||||
python-multipart==0.0.6
|
python-jose==3.3.0
|
||||||
python-jose==3.3.0
|
passlib==1.7.4
|
||||||
passlib==1.7.4
|
tenacity==8.2.3
|
||||||
tenacity==8.2.3
|
pytest==7.4.3
|
||||||
pytest==7.4.3
|
httpx==0.25.1
|
||||||
httpx==0.25.1
|
pinecone-client==2.2.4
|
||||||
pinecone-client==2.2.4
|
pillow==10.1.0
|
||||||
pillow==10.1.0
|
python-slugify==8.0.1
|
||||||
clip==0.2.0
|
email-validator==2.1.0.post1
|
||||||
torch==2.1.0
|
|
||||||
transformers==4.35.0
|
|
||||||
python-slugify==8.0.1
|
|
||||||
|
|||||||
188
scripts/README.md
Normal file
188
scripts/README.md
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
# Build and Deployment Scripts
|
||||||
|
|
||||||
|
This directory contains scripts for building and deploying the Sereact API application.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Docker installed and running
|
||||||
|
- For deployment: access to a container registry (e.g., DockerHub, Google Container Registry)
|
||||||
|
- For Cloud Run deployment: Google Cloud SDK (`gcloud`) installed and configured
|
||||||
|
|
||||||
|
## Scripts
|
||||||
|
|
||||||
|
### Build Script (`build.sh`)
|
||||||
|
|
||||||
|
Builds the Docker image for the Sereact API.
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```bash
|
||||||
|
# Basic usage (builds with default settings)
|
||||||
|
./scripts/build.sh
|
||||||
|
|
||||||
|
# Customize the image name
|
||||||
|
IMAGE_NAME=my-custom-name ./scripts/build.sh
|
||||||
|
|
||||||
|
# Customize the image tag
|
||||||
|
IMAGE_TAG=v1.0.0 ./scripts/build.sh
|
||||||
|
|
||||||
|
# Use a custom registry
|
||||||
|
REGISTRY=gcr.io/my-project ./scripts/build.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
**Environment Variables:**
|
||||||
|
- `IMAGE_NAME`: Name for the Docker image (default: "sereact-api")
|
||||||
|
- `IMAGE_TAG`: Tag for the Docker image (default: "latest")
|
||||||
|
- `REGISTRY`: Container registry to use (default: empty, using DockerHub)
|
||||||
|
|
||||||
|
### Deploy Script (`deploy.sh`)
|
||||||
|
|
||||||
|
Pushes the built Docker image to a container registry and optionally deploys to Google Cloud Run.
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```bash
|
||||||
|
# Push to registry only
|
||||||
|
./scripts/deploy.sh
|
||||||
|
|
||||||
|
# Push to registry and deploy to Cloud Run
|
||||||
|
DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=my-project-id ./scripts/deploy.sh
|
||||||
|
|
||||||
|
# Customize deployment settings
|
||||||
|
DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=my-project-id REGION=us-west1 SERVICE_NAME=my-api ./scripts/deploy.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
**Environment Variables:**
|
||||||
|
All variables from the build script, plus:
|
||||||
|
- `DEPLOY_TO_CLOUD_RUN`: Set to "true" to deploy to Cloud Run (default: "false")
|
||||||
|
- `PROJECT_ID`: Google Cloud project ID (required for Cloud Run deployment)
|
||||||
|
- `REGION`: Google Cloud region (default: "us-central1")
|
||||||
|
- `SERVICE_NAME`: Name for the Cloud Run service (default: "sereact-api")
|
||||||
|
|
||||||
|
### Cloud Run Deployment Script (`deploy-to-cloud-run.sh`)
|
||||||
|
|
||||||
|
Deploys the application to Google Cloud Run using the service configuration file.
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```bash
|
||||||
|
# Deploy using existing service.yaml
|
||||||
|
PROJECT_ID=my-project-id ./scripts/deploy-to-cloud-run.sh
|
||||||
|
|
||||||
|
# Build, push, and deploy in one command
|
||||||
|
PROJECT_ID=my-project-id BUILD=true PUSH=true ./scripts/deploy-to-cloud-run.sh
|
||||||
|
|
||||||
|
# Customize the deployment
|
||||||
|
PROJECT_ID=my-project-id REGION=us-west1 IMAGE_TAG=v1.0.0 ./scripts/deploy-to-cloud-run.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
**Environment Variables:**
|
||||||
|
- `PROJECT_ID`: Google Cloud project ID (required)
|
||||||
|
- `REGION`: Google Cloud region (default: "us-central1")
|
||||||
|
- `SERVICE_CONFIG`: Path to the service configuration file (default: "deployment/cloud-run/service.yaml")
|
||||||
|
- `IMAGE_NAME`: Name for the Docker image (default: "sereact-api")
|
||||||
|
- `IMAGE_TAG`: Tag for the Docker image (default: "latest")
|
||||||
|
- `REGISTRY`: Container registry to use (default: "gcr.io")
|
||||||
|
- `BUILD`: Set to "true" to build the image before deployment (default: "false")
|
||||||
|
- `PUSH`: Set to "true" to push the image before deployment (default: "false")
|
||||||
|
|
||||||
|
## Example Workflows
|
||||||
|
|
||||||
|
### Basic workflow:
|
||||||
|
```bash
|
||||||
|
# Build and tag with version
|
||||||
|
IMAGE_TAG=v1.0.0 ./scripts/build.sh
|
||||||
|
|
||||||
|
# Deploy to Cloud Run
|
||||||
|
DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=my-project-id IMAGE_TAG=v1.0.0 ./scripts/deploy.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using the Cloud Run config file:
|
||||||
|
```bash
|
||||||
|
# Build and deploy in one step
|
||||||
|
PROJECT_ID=my-project-id BUILD=true PUSH=true ./scripts/deploy-to-cloud-run.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
# Scripts Documentation
|
||||||
|
|
||||||
|
This directory contains utility scripts for the SEREACT application.
|
||||||
|
|
||||||
|
## Database Seeding Scripts
|
||||||
|
|
||||||
|
### `seed_firestore.py`
|
||||||
|
|
||||||
|
This script initializes and seeds a Google Cloud Firestore database with initial data for the SEREACT application. It creates teams, users, API keys, and sample image metadata.
|
||||||
|
|
||||||
|
#### Requirements
|
||||||
|
|
||||||
|
- Google Cloud project with Firestore enabled
|
||||||
|
- Google Cloud credentials configured on your machine
|
||||||
|
- Python 3.8+
|
||||||
|
- Required Python packages (listed in `requirements.txt`)
|
||||||
|
|
||||||
|
#### Setup
|
||||||
|
|
||||||
|
1. Make sure you have the Google Cloud SDK installed and configured with access to your project:
|
||||||
|
```bash
|
||||||
|
gcloud auth login
|
||||||
|
gcloud config set project YOUR_PROJECT_ID
|
||||||
|
```
|
||||||
|
|
||||||
|
2. If not using application default credentials, create a service account key file:
|
||||||
|
```bash
|
||||||
|
gcloud iam service-accounts create sereact-app
|
||||||
|
gcloud projects add-iam-policy-binding YOUR_PROJECT_ID --member="serviceAccount:sereact-app@YOUR_PROJECT_ID.iam.gserviceaccount.com" --role="roles/datastore.user"
|
||||||
|
gcloud iam service-accounts keys create credentials.json --iam-account=sereact-app@YOUR_PROJECT_ID.iam.gserviceaccount.com
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Set environment variables:
|
||||||
|
```bash
|
||||||
|
# Windows (CMD)
|
||||||
|
set DATABASE_TYPE=firestore
|
||||||
|
set GCS_CREDENTIALS_FILE=path/to/credentials.json
|
||||||
|
|
||||||
|
# Windows (PowerShell)
|
||||||
|
$env:DATABASE_TYPE="firestore"
|
||||||
|
$env:GCS_CREDENTIALS_FILE="path/to/credentials.json"
|
||||||
|
|
||||||
|
# Linux/macOS
|
||||||
|
export DATABASE_TYPE=firestore
|
||||||
|
export GCS_CREDENTIALS_FILE=path/to/credentials.json
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
Run the seeding script from the project root directory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Activate the Python virtual environment
|
||||||
|
source venv/bin/activate # Linux/macOS
|
||||||
|
venv\Scripts\activate # Windows
|
||||||
|
|
||||||
|
# Run the script
|
||||||
|
python scripts/seed_firestore.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Generated Data
|
||||||
|
|
||||||
|
The script will create the following data:
|
||||||
|
|
||||||
|
1. **Teams**:
|
||||||
|
- Sereact Development
|
||||||
|
- Marketing Team
|
||||||
|
- Customer Support
|
||||||
|
|
||||||
|
2. **Users**:
|
||||||
|
- Admin User (team: Sereact Development)
|
||||||
|
- Developer User (team: Sereact Development)
|
||||||
|
- Marketing User (team: Marketing Team)
|
||||||
|
- Support User (team: Customer Support)
|
||||||
|
|
||||||
|
3. **API Keys**:
|
||||||
|
- One API key per user (the keys will be output to the console, save them securely)
|
||||||
|
|
||||||
|
4. **Images**:
|
||||||
|
- Sample image metadata (3 images, one for each team)
|
||||||
|
|
||||||
|
#### Notes
|
||||||
|
|
||||||
|
- The script logs the generated API keys to the console. Save these keys somewhere secure as they won't be displayed again.
|
||||||
|
- If you need to re-run the script with existing data, use the `--force` flag to overwrite existing data.
|
||||||
|
- This script only creates metadata entries for images - it does not upload actual files to Google Cloud Storage.
|
||||||
29
scripts/build.sh
Normal file
29
scripts/build.sh
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Set defaults
|
||||||
|
IMAGE_NAME=${IMAGE_NAME:-"sereact-api"}
|
||||||
|
IMAGE_TAG=${IMAGE_TAG:-"latest"}
|
||||||
|
|
||||||
|
# Allow custom registry (defaults to DockerHub)
|
||||||
|
REGISTRY=${REGISTRY:-""}
|
||||||
|
REGISTRY_PREFIX=""
|
||||||
|
if [ -n "$REGISTRY" ]; then
|
||||||
|
REGISTRY_PREFIX="${REGISTRY}/"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Full image reference
|
||||||
|
FULL_IMAGE_NAME="${REGISTRY_PREFIX}${IMAGE_NAME}:${IMAGE_TAG}"
|
||||||
|
|
||||||
|
echo "Building Docker image: ${FULL_IMAGE_NAME}"
|
||||||
|
|
||||||
|
# Build the Docker image
|
||||||
|
docker build -t "${FULL_IMAGE_NAME}" -f Dockerfile .
|
||||||
|
|
||||||
|
echo "Build completed successfully"
|
||||||
|
echo "Image: ${FULL_IMAGE_NAME}"
|
||||||
|
|
||||||
|
# Print run command for testing locally
|
||||||
|
echo ""
|
||||||
|
echo "To run the image locally:"
|
||||||
|
echo "docker run -p 8000:8000 ${FULL_IMAGE_NAME}"
|
||||||
57
scripts/deploy-to-cloud-run.sh
Normal file
57
scripts/deploy-to-cloud-run.sh
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Default values
|
||||||
|
PROJECT_ID=${PROJECT_ID:-""}
|
||||||
|
REGION=${REGION:-"us-central1"}
|
||||||
|
SERVICE_CONFIG=${SERVICE_CONFIG:-"deployment/cloud-run/service.yaml"}
|
||||||
|
IMAGE_NAME=${IMAGE_NAME:-"sereact-api"}
|
||||||
|
IMAGE_TAG=${IMAGE_TAG:-"latest"}
|
||||||
|
REGISTRY=${REGISTRY:-"gcr.io"}
|
||||||
|
|
||||||
|
# Validate required parameters
|
||||||
|
if [ -z "$PROJECT_ID" ]; then
|
||||||
|
echo "Error: PROJECT_ID environment variable is required"
|
||||||
|
echo "Usage: PROJECT_ID=your-project-id ./scripts/deploy-to-cloud-run.sh"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Full image reference
|
||||||
|
FULL_IMAGE_NAME="${REGISTRY}/${PROJECT_ID}/${IMAGE_NAME}:${IMAGE_TAG}"
|
||||||
|
|
||||||
|
# Check if service config exists
|
||||||
|
if [ ! -f "$SERVICE_CONFIG" ]; then
|
||||||
|
echo "Error: Service configuration file not found at $SERVICE_CONFIG"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build the image if BUILD=true
|
||||||
|
if [ "${BUILD:-false}" = "true" ]; then
|
||||||
|
echo "Building Docker image: ${FULL_IMAGE_NAME}"
|
||||||
|
docker build -t "${FULL_IMAGE_NAME}" -f Dockerfile .
|
||||||
|
echo "Build completed successfully"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Push the image if PUSH=true
|
||||||
|
if [ "${PUSH:-false}" = "true" ]; then
|
||||||
|
echo "Pushing image to registry: ${FULL_IMAGE_NAME}"
|
||||||
|
docker push "${FULL_IMAGE_NAME}"
|
||||||
|
echo "Image pushed successfully"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Update the image in the service configuration
|
||||||
|
echo "Updating image reference in service configuration..."
|
||||||
|
TMP_CONFIG=$(mktemp)
|
||||||
|
sed "s|image: .*|image: ${FULL_IMAGE_NAME}|g" "$SERVICE_CONFIG" > "$TMP_CONFIG"
|
||||||
|
|
||||||
|
echo "Deploying to Cloud Run using configuration..."
|
||||||
|
gcloud run services replace "$TMP_CONFIG" \
|
||||||
|
--project="$PROJECT_ID" \
|
||||||
|
--region="$REGION" \
|
||||||
|
--platform=managed
|
||||||
|
|
||||||
|
rm "$TMP_CONFIG"
|
||||||
|
|
||||||
|
echo "Deployment completed successfully"
|
||||||
|
echo "Service URL: $(gcloud run services describe sereact --region=${REGION} --project=${PROJECT_ID} --format='value(status.url)')"
|
||||||
|
echo "To view logs: gcloud logging read 'resource.type=cloud_run_revision AND resource.labels.service_name=sereact' --project=$PROJECT_ID --limit=10"
|
||||||
43
scripts/deploy.sh
Normal file
43
scripts/deploy.sh
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Source the build environment to reuse variables
|
||||||
|
source "$(dirname "$0")/build.sh"
|
||||||
|
|
||||||
|
# Push the Docker image to the registry
|
||||||
|
echo "Pushing image: ${FULL_IMAGE_NAME} to registry..."
|
||||||
|
docker push "${FULL_IMAGE_NAME}"
|
||||||
|
echo "Image pushed successfully"
|
||||||
|
|
||||||
|
# Check if we need to deploy to Cloud Run
|
||||||
|
DEPLOY_TO_CLOUD_RUN=${DEPLOY_TO_CLOUD_RUN:-false}
|
||||||
|
|
||||||
|
if [ "$DEPLOY_TO_CLOUD_RUN" = true ]; then
|
||||||
|
echo "Deploying to Cloud Run..."
|
||||||
|
|
||||||
|
# Cloud Run settings
|
||||||
|
PROJECT_ID=${PROJECT_ID:-""}
|
||||||
|
REGION=${REGION:-"us-central1"}
|
||||||
|
SERVICE_NAME=${SERVICE_NAME:-"sereact-api"}
|
||||||
|
|
||||||
|
if [ -z "$PROJECT_ID" ]; then
|
||||||
|
echo "Error: PROJECT_ID environment variable is required for Cloud Run deployment"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Deploy to Cloud Run
|
||||||
|
gcloud run deploy "${SERVICE_NAME}" \
|
||||||
|
--image="${FULL_IMAGE_NAME}" \
|
||||||
|
--platform=managed \
|
||||||
|
--region="${REGION}" \
|
||||||
|
--project="${PROJECT_ID}" \
|
||||||
|
--allow-unauthenticated \
|
||||||
|
--port=8000
|
||||||
|
|
||||||
|
echo "Deployment to Cloud Run completed"
|
||||||
|
echo "Service URL: $(gcloud run services describe ${SERVICE_NAME} --region=${REGION} --project=${PROJECT_ID} --format='value(status.url)')"
|
||||||
|
else
|
||||||
|
echo ""
|
||||||
|
echo "To deploy to Cloud Run:"
|
||||||
|
echo "DEPLOY_TO_CLOUD_RUN=true PROJECT_ID=your-project-id ./scripts/deploy.sh"
|
||||||
|
fi
|
||||||
272
scripts/seed_firestore.py
Normal file
272
scripts/seed_firestore.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Script to seed the Firestore database with initial data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
from bson import ObjectId
|
||||||
|
from pydantic import HttpUrl
|
||||||
|
|
||||||
|
# Add the parent directory to the path so we can import from src
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
from src.db.models.team import TeamModel
|
||||||
|
from src.db.models.user import UserModel
|
||||||
|
from src.db.models.api_key import ApiKeyModel
|
||||||
|
from src.db.models.image import ImageModel
|
||||||
|
from src.db.providers.firestore_provider import firestore_db
|
||||||
|
from src.db.repositories.firestore_team_repository import firestore_team_repository
|
||||||
|
from src.db.repositories.firestore_user_repository import firestore_user_repository
|
||||||
|
from src.db.repositories.firestore_api_key_repository import firestore_api_key_repository
|
||||||
|
from src.db.repositories.firestore_image_repository import firestore_image_repository
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def generate_api_key(length=32):
|
||||||
|
"""Generate a random API key"""
|
||||||
|
return secrets.token_hex(length)
|
||||||
|
|
||||||
|
def hash_api_key(api_key):
|
||||||
|
"""Hash an API key for storage"""
|
||||||
|
return hashlib.sha256(api_key.encode()).hexdigest()
|
||||||
|
|
||||||
|
async def seed_teams():
|
||||||
|
"""Seed the database with team data"""
|
||||||
|
logger.info("Seeding teams...")
|
||||||
|
|
||||||
|
teams_data = [
|
||||||
|
{
|
||||||
|
"name": "Sereact Development",
|
||||||
|
"description": "Internal development team"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Marketing Team",
|
||||||
|
"description": "Marketing and design team"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Customer Support",
|
||||||
|
"description": "Customer support and success team"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
team_ids = []
|
||||||
|
for team_data in teams_data:
|
||||||
|
team = TeamModel(**team_data)
|
||||||
|
created_team = await firestore_team_repository.create(team)
|
||||||
|
team_ids.append(created_team.id)
|
||||||
|
logger.info(f"Created team: {created_team.name} (ID: {created_team.id})")
|
||||||
|
|
||||||
|
return team_ids
|
||||||
|
|
||||||
|
async def seed_users(team_ids):
|
||||||
|
"""Seed the database with user data"""
|
||||||
|
logger.info("Seeding users...")
|
||||||
|
|
||||||
|
users_data = [
|
||||||
|
{
|
||||||
|
"email": "admin@sereact.com",
|
||||||
|
"name": "Admin User",
|
||||||
|
"team_id": team_ids[0],
|
||||||
|
"is_admin": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "developer@sereact.com",
|
||||||
|
"name": "Developer User",
|
||||||
|
"team_id": team_ids[0]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "marketing@sereact.com",
|
||||||
|
"name": "Marketing User",
|
||||||
|
"team_id": team_ids[1]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "support@sereact.com",
|
||||||
|
"name": "Support User",
|
||||||
|
"team_id": team_ids[2]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
user_ids = []
|
||||||
|
for user_data in users_data:
|
||||||
|
user = UserModel(**user_data)
|
||||||
|
created_user = await firestore_user_repository.create(user)
|
||||||
|
user_ids.append(created_user.id)
|
||||||
|
logger.info(f"Created user: {created_user.name} (ID: {created_user.id})")
|
||||||
|
|
||||||
|
return user_ids
|
||||||
|
|
||||||
|
async def seed_api_keys(user_ids, team_ids):
|
||||||
|
"""Seed the database with API key data"""
|
||||||
|
logger.info("Seeding API keys...")
|
||||||
|
|
||||||
|
api_keys_data = [
|
||||||
|
{
|
||||||
|
"user_id": user_ids[0],
|
||||||
|
"team_id": team_ids[0],
|
||||||
|
"name": "Admin Key",
|
||||||
|
"description": "API key for admin user"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": user_ids[1],
|
||||||
|
"team_id": team_ids[0],
|
||||||
|
"name": "Development Key",
|
||||||
|
"description": "API key for development user"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": user_ids[2],
|
||||||
|
"team_id": team_ids[1],
|
||||||
|
"name": "Marketing Key",
|
||||||
|
"description": "API key for marketing user"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": user_ids[3],
|
||||||
|
"team_id": team_ids[2],
|
||||||
|
"name": "Support Key",
|
||||||
|
"description": "API key for support user"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
generated_keys = []
|
||||||
|
for api_key_data in api_keys_data:
|
||||||
|
# Generate a unique API key
|
||||||
|
api_key = generate_api_key()
|
||||||
|
key_hash = hash_api_key(api_key)
|
||||||
|
|
||||||
|
# Create API key object
|
||||||
|
api_key_data["key_hash"] = key_hash
|
||||||
|
api_key_data["expiry_date"] = datetime.utcnow() + timedelta(days=365)
|
||||||
|
|
||||||
|
api_key_obj = ApiKeyModel(**api_key_data)
|
||||||
|
created_api_key = await firestore_api_key_repository.create(api_key_obj)
|
||||||
|
|
||||||
|
generated_keys.append({
|
||||||
|
"id": created_api_key.id,
|
||||||
|
"key": api_key,
|
||||||
|
"name": created_api_key.name
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"Created API key: {created_api_key.name} (ID: {created_api_key.id})")
|
||||||
|
|
||||||
|
# Print the generated keys for reference
|
||||||
|
logger.info("\nGenerated API Keys (save these somewhere secure):")
|
||||||
|
for key in generated_keys:
|
||||||
|
logger.info(f"Name: {key['name']}, Key: {key['key']}")
|
||||||
|
|
||||||
|
return generated_keys
|
||||||
|
|
||||||
|
async def seed_images(team_ids, user_ids):
|
||||||
|
"""Seed the database with image metadata"""
|
||||||
|
logger.info("Seeding images...")
|
||||||
|
|
||||||
|
images_data = [
|
||||||
|
{
|
||||||
|
"filename": "image1.jpg",
|
||||||
|
"original_filename": "product_photo.jpg",
|
||||||
|
"file_size": 1024 * 1024, # 1MB
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"storage_path": "teams/{}/images/image1.jpg".format(team_ids[0]),
|
||||||
|
"public_url": "https://storage.googleapis.com/example-bucket/teams/{}/images/image1.jpg".format(team_ids[0]),
|
||||||
|
"team_id": team_ids[0],
|
||||||
|
"uploader_id": user_ids[0],
|
||||||
|
"description": "Product photo for marketing",
|
||||||
|
"tags": ["product", "marketing", "high-resolution"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1920,
|
||||||
|
"height": 1080,
|
||||||
|
"color_space": "sRGB"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filename": "image2.png",
|
||||||
|
"original_filename": "logo.png",
|
||||||
|
"file_size": 512 * 1024, # 512KB
|
||||||
|
"content_type": "image/png",
|
||||||
|
"storage_path": "teams/{}/images/image2.png".format(team_ids[1]),
|
||||||
|
"public_url": "https://storage.googleapis.com/example-bucket/teams/{}/images/image2.png".format(team_ids[1]),
|
||||||
|
"team_id": team_ids[1],
|
||||||
|
"uploader_id": user_ids[2],
|
||||||
|
"description": "Company logo",
|
||||||
|
"tags": ["logo", "branding"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 800,
|
||||||
|
"height": 600,
|
||||||
|
"color_space": "sRGB"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"filename": "image3.jpg",
|
||||||
|
"original_filename": "support_screenshot.jpg",
|
||||||
|
"file_size": 256 * 1024, # 256KB
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"storage_path": "teams/{}/images/image3.jpg".format(team_ids[2]),
|
||||||
|
"public_url": "https://storage.googleapis.com/example-bucket/teams/{}/images/image3.jpg".format(team_ids[2]),
|
||||||
|
"team_id": team_ids[2],
|
||||||
|
"uploader_id": user_ids[3],
|
||||||
|
"description": "Screenshot for support ticket",
|
||||||
|
"tags": ["support", "screenshot", "bug"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1280,
|
||||||
|
"height": 720,
|
||||||
|
"color_space": "sRGB"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
image_ids = []
|
||||||
|
for image_data in images_data:
|
||||||
|
image = ImageModel(**image_data)
|
||||||
|
created_image = await firestore_image_repository.create(image)
|
||||||
|
image_ids.append(created_image.id)
|
||||||
|
logger.info(f"Created image: {created_image.filename} (ID: {created_image.id})")
|
||||||
|
|
||||||
|
return image_ids
|
||||||
|
|
||||||
|
async def seed_database():
|
||||||
|
"""Seed the database with initial data"""
|
||||||
|
try:
|
||||||
|
# Connect to Firestore
|
||||||
|
firestore_db.connect()
|
||||||
|
|
||||||
|
# Seed teams first
|
||||||
|
team_ids = await seed_teams()
|
||||||
|
|
||||||
|
# Seed users with team IDs
|
||||||
|
user_ids = await seed_users(team_ids)
|
||||||
|
|
||||||
|
# Seed API keys with user and team IDs
|
||||||
|
api_keys = await seed_api_keys(user_ids, team_ids)
|
||||||
|
|
||||||
|
# Seed images with team and user IDs
|
||||||
|
image_ids = await seed_images(team_ids, user_ids)
|
||||||
|
|
||||||
|
logger.info("Database seeding completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error seeding database: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Disconnect from Firestore
|
||||||
|
firestore_db.disconnect()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point"""
|
||||||
|
parser = argparse.ArgumentParser(description="Seed the Firestore database with initial data")
|
||||||
|
parser.add_argument("--force", action="store_true", help="Force seeding even if data exists")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(seed_database())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
from src.api.v1 import teams, auth
|
||||||
|
from src.api.v1 import users, images, search
|
||||||
13
src/api/v1/images.py
Normal file
13
src/api/v1/images.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from src.api.v1.auth import get_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Images"], prefix="/images")
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_images(current_user = Depends(get_current_user)):
|
||||||
|
"""List images (placeholder endpoint)"""
|
||||||
|
return {"message": "Images listing functionality to be implemented"}
|
||||||
16
src/api/v1/search.py
Normal file
16
src/api/v1/search.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
|
||||||
|
from src.api.v1.auth import get_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Search"], prefix="/search")
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def search_images(
|
||||||
|
q: str = Query(..., description="Search query"),
|
||||||
|
current_user = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""Search for images (placeholder endpoint)"""
|
||||||
|
return {"message": "Search functionality to be implemented", "query": q}
|
||||||
13
src/api/v1/users.py
Normal file
13
src/api/v1/users.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from src.api.v1.auth import get_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Users"], prefix="/users")
|
||||||
|
|
||||||
|
@router.get("/me")
|
||||||
|
async def read_users_me(current_user = Depends(get_current_user)):
|
||||||
|
"""Get current user information"""
|
||||||
|
return current_user
|
||||||
@ -1,47 +1,49 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, ClassVar
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from pydantic import AnyHttpUrl, validator
|
from pydantic import AnyHttpUrl, field_validator
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
# Project settings
|
# Project settings
|
||||||
PROJECT_NAME: str = "Image Management API"
|
PROJECT_NAME: str = "Image Management API"
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
# CORS settings
|
# CORS settings
|
||||||
CORS_ORIGINS: List[str] = ["*"]
|
CORS_ORIGINS: List[str] = ["*"]
|
||||||
|
|
||||||
@validator("CORS_ORIGINS", pre=True)
|
@field_validator("CORS_ORIGINS", mode="before")
|
||||||
def assemble_cors_origins(cls, v):
|
def assemble_cors_origins(cls, v):
|
||||||
if isinstance(v, str) and not v.startswith("["):
|
if isinstance(v, str) and not v.startswith("["):
|
||||||
return [i.strip() for i in v.split(",")]
|
return [i.strip() for i in v.split(",")]
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Database settings
|
# Database settings
|
||||||
DATABASE_URI: str = os.getenv("DATABASE_URI", "mongodb://localhost:27017")
|
DATABASE_NAME: str = os.getenv("DATABASE_NAME", "imagedb")
|
||||||
DATABASE_NAME: str = os.getenv("DATABASE_NAME", "imagedb")
|
FIRESTORE_PROJECT_ID: str = os.getenv("FIRESTORE_PROJECT_ID", "")
|
||||||
|
FIRESTORE_CREDENTIALS_FILE: str = os.getenv("FIRESTORE_CREDENTIALS_FILE", "firestore-credentials.json")
|
||||||
# Google Cloud Storage settings
|
|
||||||
GCS_BUCKET_NAME: str = os.getenv("GCS_BUCKET_NAME", "image-mgmt-bucket")
|
# Google Cloud Storage settings
|
||||||
GCS_CREDENTIALS_FILE: str = os.getenv("GCS_CREDENTIALS_FILE", "credentials.json")
|
GCS_BUCKET_NAME: str = os.getenv("GCS_BUCKET_NAME", "image-mgmt-bucket")
|
||||||
|
GCS_CREDENTIALS_FILE: str = os.getenv("GCS_CREDENTIALS_FILE", "credentials.json")
|
||||||
# Security settings
|
|
||||||
API_KEY_SECRET: str = os.getenv("API_KEY_SECRET", "super-secret-key-for-development-only")
|
# Security settings
|
||||||
API_KEY_EXPIRY_DAYS: int = int(os.getenv("API_KEY_EXPIRY_DAYS", "365"))
|
API_KEY_SECRET: str = os.getenv("API_KEY_SECRET", "super-secret-key-for-development-only")
|
||||||
|
API_KEY_EXPIRY_DAYS: int = int(os.getenv("API_KEY_EXPIRY_DAYS", "365"))
|
||||||
# Vector Database settings (for image embeddings)
|
|
||||||
VECTOR_DB_API_KEY: str = os.getenv("VECTOR_DB_API_KEY", "")
|
# Vector Database settings (for image embeddings)
|
||||||
VECTOR_DB_ENVIRONMENT: str = os.getenv("VECTOR_DB_ENVIRONMENT", "")
|
VECTOR_DB_API_KEY: str = os.getenv("VECTOR_DB_API_KEY", "")
|
||||||
VECTOR_DB_INDEX_NAME: str = os.getenv("VECTOR_DB_INDEX_NAME", "image-embeddings")
|
VECTOR_DB_ENVIRONMENT: str = os.getenv("VECTOR_DB_ENVIRONMENT", "")
|
||||||
|
VECTOR_DB_INDEX_NAME: str = os.getenv("VECTOR_DB_INDEX_NAME", "image-embeddings")
|
||||||
# Rate limiting
|
|
||||||
RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "100"))
|
# Rate limiting
|
||||||
|
RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "100"))
|
||||||
# Logging
|
|
||||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
# Logging
|
||||||
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
class Config:
|
|
||||||
case_sensitive = True
|
model_config: ClassVar[dict] = {
|
||||||
env_file = ".env"
|
"case_sensitive": True,
|
||||||
|
"env_file": ".env"
|
||||||
|
}
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
@ -1,33 +1,33 @@
|
|||||||
import logging
|
import logging
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from google.cloud import firestore
|
||||||
from src.core.config import settings
|
from src.core.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
client: AsyncIOMotorClient = None
|
client = None
|
||||||
|
|
||||||
def connect_to_database(self):
|
def connect_to_database(self):
|
||||||
"""Create database connection."""
|
"""Create database connection."""
|
||||||
try:
|
try:
|
||||||
self.client = AsyncIOMotorClient(settings.DATABASE_URI)
|
self.client = firestore.Client()
|
||||||
logger.info("Connected to MongoDB")
|
logger.info("Connected to Firestore")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to MongoDB: {e}")
|
logger.error(f"Failed to connect to Firestore: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def close_database_connection(self):
|
def close_database_connection(self):
|
||||||
"""Close database connection."""
|
"""Close database connection."""
|
||||||
try:
|
try:
|
||||||
if self.client:
|
# No explicit close method needed for Firestore client
|
||||||
self.client.close()
|
self.client = None
|
||||||
logger.info("Closed MongoDB connection")
|
logger.info("Closed Firestore connection")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to close MongoDB connection: {e}")
|
logger.error(f"Failed to close Firestore connection: {e}")
|
||||||
|
|
||||||
def get_database(self):
|
def get_database(self):
|
||||||
"""Get the database instance."""
|
"""Get the database instance."""
|
||||||
return self.client[settings.DATABASE_NAME]
|
return self.client
|
||||||
|
|
||||||
# Create a singleton database instance
|
# Create a singleton database instance
|
||||||
db = Database()
|
db = Database()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional, ClassVar
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
@ -18,9 +18,10 @@ class ApiKeyModel(BaseModel):
|
|||||||
last_used: Optional[datetime] = None
|
last_used: Optional[datetime] = None
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
allow_population_by_field_name = True
|
"populate_by_name": True,
|
||||||
arbitrary_types_allowed = True
|
"arbitrary_types_allowed": True,
|
||||||
json_encoders = {
|
"json_encoders": {
|
||||||
ObjectId: str
|
ObjectId: str
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any, ClassVar
|
||||||
from pydantic import BaseModel, Field, HttpUrl
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
@ -27,9 +27,10 @@ class ImageModel(BaseModel):
|
|||||||
embedding_model: Optional[str] = None
|
embedding_model: Optional[str] = None
|
||||||
has_embedding: bool = False
|
has_embedding: bool = False
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
allow_population_by_field_name = True
|
"populate_by_name": True,
|
||||||
arbitrary_types_allowed = True
|
"arbitrary_types_allowed": True,
|
||||||
json_encoders = {
|
"json_encoders": {
|
||||||
ObjectId: str
|
ObjectId: str
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -1,6 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Any, ClassVar
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, GetJsonSchemaHandler
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
class PyObjectId(ObjectId):
|
class PyObjectId(ObjectId):
|
||||||
@ -15,8 +16,12 @@ class PyObjectId(ObjectId):
|
|||||||
return ObjectId(v)
|
return ObjectId(v)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __modify_schema__(cls, field_schema):
|
def __get_pydantic_json_schema__(
|
||||||
|
cls, __core_schema: Any, __field_schema: Any, __handler: GetJsonSchemaHandler
|
||||||
|
) -> JsonSchemaValue:
|
||||||
|
field_schema = __handler(__core_schema)
|
||||||
field_schema.update(type='string')
|
field_schema.update(type='string')
|
||||||
|
return field_schema
|
||||||
|
|
||||||
class TeamModel(BaseModel):
|
class TeamModel(BaseModel):
|
||||||
"""Database model for a team"""
|
"""Database model for a team"""
|
||||||
@ -26,9 +31,10 @@ class TeamModel(BaseModel):
|
|||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
allow_population_by_field_name = True
|
"populate_by_name": True,
|
||||||
arbitrary_types_allowed = True
|
"arbitrary_types_allowed": True,
|
||||||
json_encoders = {
|
"json_encoders": {
|
||||||
ObjectId: str
|
ObjectId: str
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
from typing import Optional, List, ClassVar
|
||||||
from pydantic import BaseModel, Field, EmailStr
|
from pydantic import BaseModel, Field, EmailStr
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
@ -17,9 +17,10 @@ class UserModel(BaseModel):
|
|||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
last_login: Optional[datetime] = None
|
last_login: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
allow_population_by_field_name = True
|
"populate_by_name": True,
|
||||||
arbitrary_types_allowed = True
|
"arbitrary_types_allowed": True,
|
||||||
json_encoders = {
|
"json_encoders": {
|
||||||
ObjectId: str
|
ObjectId: str
|
||||||
}
|
}
|
||||||
|
}
|
||||||
202
src/db/providers/firestore_provider.py
Normal file
202
src/db/providers/firestore_provider.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from google.cloud import firestore
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.db.models.team import TeamModel
|
||||||
|
from src.db.models.user import UserModel
|
||||||
|
from src.db.models.api_key import ApiKeyModel
|
||||||
|
from src.db.models.image import ImageModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class FirestoreProvider:
|
||||||
|
"""Provider for Firestore database operations"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client = None
|
||||||
|
self._db = None
|
||||||
|
self._collections = {
|
||||||
|
"teams": TeamModel,
|
||||||
|
"users": UserModel,
|
||||||
|
"api_keys": ApiKeyModel,
|
||||||
|
"images": ImageModel
|
||||||
|
}
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
"""Connect to Firestore"""
|
||||||
|
try:
|
||||||
|
if settings.GCS_CREDENTIALS_FILE and os.path.exists(settings.GCS_CREDENTIALS_FILE):
|
||||||
|
self.client = firestore.Client.from_service_account_json(settings.GCS_CREDENTIALS_FILE)
|
||||||
|
else:
|
||||||
|
# Use application default credentials
|
||||||
|
self.client = firestore.Client()
|
||||||
|
|
||||||
|
self._db = self.client
|
||||||
|
logger.info("Connected to Firestore")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Firestore: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"""Disconnect from Firestore"""
|
||||||
|
try:
|
||||||
|
self.client = None
|
||||||
|
self._db = None
|
||||||
|
logger.info("Disconnected from Firestore")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error disconnecting from Firestore: {e}")
|
||||||
|
|
||||||
|
def get_collection(self, collection_name: str):
|
||||||
|
"""Get a Firestore collection reference"""
|
||||||
|
if not self._db:
|
||||||
|
raise ValueError("Not connected to Firestore")
|
||||||
|
return self._db.collection(collection_name)
|
||||||
|
|
||||||
|
async def add_document(self, collection_name: str, data: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Add a document to a collection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Collection name
|
||||||
|
data: Document data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
collection = self.get_collection(collection_name)
|
||||||
|
|
||||||
|
# Handle ObjectId conversion for Firestore
|
||||||
|
for key, value in data.items():
|
||||||
|
if hasattr(value, '__str__') and key != 'id':
|
||||||
|
data[key] = str(value)
|
||||||
|
|
||||||
|
# Handle special case for document ID
|
||||||
|
doc_id = None
|
||||||
|
if "_id" in data:
|
||||||
|
doc_id = str(data["_id"])
|
||||||
|
del data["_id"]
|
||||||
|
|
||||||
|
# Add document to Firestore
|
||||||
|
if doc_id:
|
||||||
|
doc_ref = collection.document(doc_id)
|
||||||
|
doc_ref.set(data)
|
||||||
|
return doc_id
|
||||||
|
else:
|
||||||
|
doc_ref = collection.add(data)
|
||||||
|
return doc_ref[1].id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error adding document to {collection_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_document(self, collection_name: str, doc_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get a document by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Collection name
|
||||||
|
doc_id: Document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document data if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
doc_ref = self.get_collection(collection_name).document(doc_id)
|
||||||
|
doc = doc_ref.get()
|
||||||
|
if doc.exists:
|
||||||
|
data = doc.to_dict()
|
||||||
|
data["_id"] = doc_id
|
||||||
|
return data
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting document from {collection_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all documents in a collection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Collection name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of documents
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
docs = self.get_collection(collection_name).stream()
|
||||||
|
results = []
|
||||||
|
for doc in docs:
|
||||||
|
data = doc.to_dict()
|
||||||
|
data["_id"] = doc.id
|
||||||
|
results.append(data)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error listing documents in {collection_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Update a document
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Collection name
|
||||||
|
doc_id: Document ID
|
||||||
|
data: Update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document was updated, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Process data for Firestore
|
||||||
|
processed_data = {}
|
||||||
|
for key, value in data.items():
|
||||||
|
if key != "_id" and hasattr(value, '__str__'):
|
||||||
|
processed_data[key] = str(value)
|
||||||
|
elif key != "_id":
|
||||||
|
processed_data[key] = value
|
||||||
|
|
||||||
|
doc_ref = self.get_collection(collection_name).document(doc_id)
|
||||||
|
doc_ref.update(processed_data)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating document in {collection_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_document(self, collection_name: str, doc_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a document
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Collection name
|
||||||
|
doc_id: Document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document was deleted, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
doc_ref = self.get_collection(collection_name).document(doc_id)
|
||||||
|
doc_ref.delete()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting document from {collection_name}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def convert_to_model(self, model_class: Type[BaseModel], doc_data: Dict[str, Any]) -> BaseModel:
|
||||||
|
"""
|
||||||
|
Convert Firestore document data to a Pydantic model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_class: Pydantic model class
|
||||||
|
doc_data: Firestore document data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model instance
|
||||||
|
"""
|
||||||
|
return model_class(**doc_data)
|
||||||
|
|
||||||
|
# Create a singleton provider
|
||||||
|
firestore_db = FirestoreProvider()
|
||||||
55
src/db/repositories/firestore_api_key_repository.py
Normal file
55
src/db/repositories/firestore_api_key_repository.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import logging
|
||||||
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
|
from src.db.models.api_key import ApiKeyModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
||||||
|
"""Repository for API key operations using Firestore"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("api_keys", ApiKeyModel)
|
||||||
|
|
||||||
|
async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel:
|
||||||
|
"""
|
||||||
|
Get API key by hash
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_hash: Hashed API key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all API keys and filter in memory
|
||||||
|
api_keys = await self.get_all()
|
||||||
|
for api_key in api_keys:
|
||||||
|
if api_key.key_hash == key_hash:
|
||||||
|
return api_key
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API key by hash: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]:
|
||||||
|
"""
|
||||||
|
Get API keys by user ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of API keys
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all API keys and filter in memory
|
||||||
|
api_keys = await self.get_all()
|
||||||
|
return [api_key for api_key in api_keys if str(api_key.user_id) == str(user_id)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API keys by user ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
firestore_api_key_repository = FirestoreApiKeyRepository()
|
||||||
71
src/db/repositories/firestore_image_repository.py
Normal file
71
src/db/repositories/firestore_image_repository.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import logging
|
||||||
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
|
from src.db.models.image import ImageModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class FirestoreImageRepository(FirestoreRepository[ImageModel]):
|
||||||
|
"""Repository for image operations using Firestore"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("images", ImageModel)
|
||||||
|
|
||||||
|
async def get_by_team_id(self, team_id: str) -> list[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get images by team ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of images
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all images and filter in memory
|
||||||
|
images = await self.get_all()
|
||||||
|
return [image for image in images if str(image.team_id) == str(team_id)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting images by team ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_uploader_id(self, uploader_id: str) -> list[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get images by uploader ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uploader_id: Uploader ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of images
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all images and filter in memory
|
||||||
|
images = await self.get_all()
|
||||||
|
return [image for image in images if str(image.uploader_id) == str(uploader_id)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting images by uploader ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_tag(self, tag: str) -> list[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get images by tag
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: Tag
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of images
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all images and filter in memory
|
||||||
|
images = await self.get_all()
|
||||||
|
return [image for image in images if tag in image.tags]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting images by tag: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
firestore_image_repository = FirestoreImageRepository()
|
||||||
121
src/db/repositories/firestore_repository.py
Normal file
121
src/db/repositories/firestore_repository.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Type, Any, Generic, TypeVar
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from src.db.providers.firestore_provider import firestore_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar('T', bound=BaseModel)
|
||||||
|
|
||||||
|
class FirestoreRepository(Generic[T]):
|
||||||
|
"""Generic repository for Firestore operations"""
|
||||||
|
|
||||||
|
def __init__(self, collection_name: str, model_class: Type[T]):
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.model_class = model_class
|
||||||
|
|
||||||
|
async def create(self, model: T) -> T:
|
||||||
|
"""
|
||||||
|
Create a new document
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created model with ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Convert Pydantic model to dict
|
||||||
|
model_dict = model.dict(by_alias=True)
|
||||||
|
|
||||||
|
# Add document to Firestore
|
||||||
|
doc_id = await firestore_db.add_document(self.collection_name, model_dict)
|
||||||
|
|
||||||
|
# Get the created document
|
||||||
|
doc_data = await firestore_db.get_document(self.collection_name, doc_id)
|
||||||
|
return self.model_class(**doc_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating {self.collection_name} document: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id(self, doc_id: str) -> Optional[T]:
|
||||||
|
"""
|
||||||
|
Get document by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: Document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
doc_data = await firestore_db.get_document(self.collection_name, str(doc_id))
|
||||||
|
if doc_data:
|
||||||
|
return self.model_class(**doc_data)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting {self.collection_name} document by ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_all(self) -> List[T]:
|
||||||
|
"""
|
||||||
|
Get all documents
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of models
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
docs = await firestore_db.list_documents(self.collection_name)
|
||||||
|
return [self.model_class(**doc) for doc in docs]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting all {self.collection_name} documents: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update(self, doc_id: str, update_data: Dict[str, Any]) -> Optional[T]:
|
||||||
|
"""
|
||||||
|
Update document
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: Document ID
|
||||||
|
update_data: Update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated model if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Remove _id from update data
|
||||||
|
if "_id" in update_data:
|
||||||
|
del update_data["_id"]
|
||||||
|
|
||||||
|
# Update document
|
||||||
|
success = await firestore_db.update_document(
|
||||||
|
self.collection_name,
|
||||||
|
str(doc_id),
|
||||||
|
update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get updated document
|
||||||
|
return await self.get_by_id(doc_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating {self.collection_name} document: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, doc_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete document
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: Document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if document was deleted, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await firestore_db.delete_document(self.collection_name, str(doc_id))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting {self.collection_name} document: {e}")
|
||||||
|
raise
|
||||||
14
src/db/repositories/firestore_team_repository.py
Normal file
14
src/db/repositories/firestore_team_repository.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import logging
|
||||||
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
|
from src.db.models.team import TeamModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
||||||
|
"""Repository for team operations using Firestore"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("teams", TeamModel)
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
firestore_team_repository = FirestoreTeamRepository()
|
||||||
55
src/db/repositories/firestore_user_repository.py
Normal file
55
src/db/repositories/firestore_user_repository.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import logging
|
||||||
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
|
from src.db.models.user import UserModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
||||||
|
"""Repository for user operations using Firestore"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("users", UserModel)
|
||||||
|
|
||||||
|
async def get_by_email(self, email: str) -> UserModel:
|
||||||
|
"""
|
||||||
|
Get user by email
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: User email
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all users and filter in memory
|
||||||
|
users = await self.get_all()
|
||||||
|
for user in users:
|
||||||
|
if user.email == email:
|
||||||
|
return user
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user by email: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_team_id(self, team_id: str) -> list[UserModel]:
|
||||||
|
"""
|
||||||
|
Get users by team ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of users
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would typically use a Firestore query, but for simplicity
|
||||||
|
# we'll get all users and filter in memory
|
||||||
|
users = await self.get_all()
|
||||||
|
return [user for user in users if str(user.team_id) == str(team_id)]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting users by team ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
firestore_user_repository = FirestoreUserRepository()
|
||||||
56
src/db/repositories/repository_factory.py
Normal file
56
src/db/repositories/repository_factory.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Type, Any
|
||||||
|
|
||||||
|
# Import Firestore repositories
|
||||||
|
from src.db.repositories.firestore_team_repository import firestore_team_repository
|
||||||
|
from src.db.repositories.firestore_user_repository import firestore_user_repository
|
||||||
|
from src.db.repositories.firestore_api_key_repository import firestore_api_key_repository
|
||||||
|
from src.db.repositories.firestore_image_repository import firestore_image_repository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DatabaseType(str, Enum):
|
||||||
|
"""Database types"""
|
||||||
|
FIRESTORE = "firestore"
|
||||||
|
|
||||||
|
class RepositoryFactory:
|
||||||
|
"""Factory for creating repositories"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Repository mappings
|
||||||
|
self.team_repositories = {
|
||||||
|
DatabaseType.FIRESTORE: firestore_team_repository
|
||||||
|
}
|
||||||
|
|
||||||
|
self.user_repositories = {
|
||||||
|
DatabaseType.FIRESTORE: firestore_user_repository
|
||||||
|
}
|
||||||
|
|
||||||
|
self.api_key_repositories = {
|
||||||
|
DatabaseType.FIRESTORE: firestore_api_key_repository
|
||||||
|
}
|
||||||
|
|
||||||
|
self.image_repositories = {
|
||||||
|
DatabaseType.FIRESTORE: firestore_image_repository
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_team_repository(self):
|
||||||
|
"""Get team repository"""
|
||||||
|
return self.team_repositories[DatabaseType.FIRESTORE]
|
||||||
|
|
||||||
|
def get_user_repository(self):
|
||||||
|
"""Get user repository"""
|
||||||
|
return self.user_repositories[DatabaseType.FIRESTORE]
|
||||||
|
|
||||||
|
def get_api_key_repository(self):
|
||||||
|
"""Get API key repository"""
|
||||||
|
return self.api_key_repositories[DatabaseType.FIRESTORE]
|
||||||
|
|
||||||
|
def get_image_repository(self):
|
||||||
|
"""Get image repository"""
|
||||||
|
return self.image_repositories[DatabaseType.FIRESTORE]
|
||||||
|
|
||||||
|
# Create singleton factory
|
||||||
|
repository_factory = RepositoryFactory()
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional, ClassVar
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -27,9 +27,9 @@ class ApiKeyResponse(ApiKeyBase):
|
|||||||
last_used: Optional[datetime] = None
|
last_used: Optional[datetime] = None
|
||||||
is_active: bool
|
is_active: bool
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
from_attributes = True
|
"from_attributes": True,
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"id": "507f1f77bcf86cd799439011",
|
"id": "507f1f77bcf86cd799439011",
|
||||||
"name": "Development API Key",
|
"name": "Development API Key",
|
||||||
@ -42,13 +42,14 @@ class ApiKeyResponse(ApiKeyBase):
|
|||||||
"is_active": True
|
"is_active": True
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ApiKeyWithValueResponse(ApiKeyResponse):
|
class ApiKeyWithValueResponse(ApiKeyResponse):
|
||||||
"""Schema for API key response with the raw value"""
|
"""Schema for API key response with the raw value"""
|
||||||
key: str
|
key: str
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"id": "507f1f77bcf86cd799439011",
|
"id": "507f1f77bcf86cd799439011",
|
||||||
"name": "Development API Key",
|
"name": "Development API Key",
|
||||||
@ -62,14 +63,15 @@ class ApiKeyWithValueResponse(ApiKeyResponse):
|
|||||||
"key": "abc123.xyzabc123def456"
|
"key": "abc123.xyzabc123def456"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ApiKeyListResponse(BaseModel):
|
class ApiKeyListResponse(BaseModel):
|
||||||
"""Schema for API key list response"""
|
"""Schema for API key list response"""
|
||||||
api_keys: List[ApiKeyResponse]
|
api_keys: List[ApiKeyResponse]
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"api_keys": [
|
"api_keys": [
|
||||||
{
|
{
|
||||||
@ -86,4 +88,5 @@ class ApiKeyListResponse(BaseModel):
|
|||||||
],
|
],
|
||||||
"total": 1
|
"total": 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any, ClassVar
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field, HttpUrl
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
@ -33,9 +33,9 @@ class ImageResponse(ImageBase):
|
|||||||
metadata: Dict[str, Any] = Field(default={})
|
metadata: Dict[str, Any] = Field(default={})
|
||||||
has_embedding: bool = False
|
has_embedding: bool = False
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
from_attributes = True
|
"from_attributes": True,
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"id": "507f1f77bcf86cd799439011",
|
"id": "507f1f77bcf86cd799439011",
|
||||||
"filename": "1234567890abcdef.jpg",
|
"filename": "1234567890abcdef.jpg",
|
||||||
@ -57,6 +57,7 @@ class ImageResponse(ImageBase):
|
|||||||
"has_embedding": True
|
"has_embedding": True
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ImageListResponse(BaseModel):
|
class ImageListResponse(BaseModel):
|
||||||
"""Schema for image list response"""
|
"""Schema for image list response"""
|
||||||
@ -66,8 +67,8 @@ class ImageListResponse(BaseModel):
|
|||||||
page_size: int
|
page_size: int
|
||||||
total_pages: int
|
total_pages: int
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"images": [
|
"images": [
|
||||||
{
|
{
|
||||||
@ -97,27 +98,29 @@ class ImageListResponse(BaseModel):
|
|||||||
"total_pages": 1
|
"total_pages": 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ImageSearchQuery(BaseModel):
|
class ImageSearchQuery(BaseModel):
|
||||||
"""Schema for image search query"""
|
"""Schema for image search query"""
|
||||||
query: str = Field(..., description="Search query", min_length=1)
|
query: str = Field(..., description="Search query", min_length=1)
|
||||||
limit: int = Field(10, description="Maximum number of results", ge=1, le=100)
|
limit: int = Field(10, description="Maximum number of results", ge=1, le=100)
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"query": "mountain sunset",
|
"query": "mountain sunset",
|
||||||
"limit": 10
|
"limit": 10
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ImageSearchResult(BaseModel):
|
class ImageSearchResult(BaseModel):
|
||||||
"""Schema for image search result"""
|
"""Schema for image search result"""
|
||||||
image: ImageResponse
|
image: ImageResponse
|
||||||
score: float
|
score: float
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"image": {
|
"image": {
|
||||||
"id": "507f1f77bcf86cd799439011",
|
"id": "507f1f77bcf86cd799439011",
|
||||||
@ -142,6 +145,7 @@ class ImageSearchResult(BaseModel):
|
|||||||
"score": 0.95
|
"score": 0.95
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ImageSearchResponse(BaseModel):
|
class ImageSearchResponse(BaseModel):
|
||||||
"""Schema for image search response"""
|
"""Schema for image search response"""
|
||||||
@ -149,8 +153,8 @@ class ImageSearchResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
query: str
|
query: str
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"results": [
|
"results": [
|
||||||
{
|
{
|
||||||
@ -180,4 +184,5 @@ class ImageSearchResponse(BaseModel):
|
|||||||
"total": 1,
|
"total": 1,
|
||||||
"query": "mountain sunset"
|
"query": "mountain sunset"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional, ClassVar
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -22,9 +22,9 @@ class TeamResponse(TeamBase):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
from_attributes = True
|
"from_attributes": True,
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"id": "507f1f77bcf86cd799439011",
|
"id": "507f1f77bcf86cd799439011",
|
||||||
"name": "Marketing Team",
|
"name": "Marketing Team",
|
||||||
@ -33,14 +33,15 @@ class TeamResponse(TeamBase):
|
|||||||
"updated_at": None
|
"updated_at": None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class TeamListResponse(BaseModel):
|
class TeamListResponse(BaseModel):
|
||||||
"""Schema for team list response"""
|
"""Schema for team list response"""
|
||||||
teams: List[TeamResponse]
|
teams: List[TeamResponse]
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"teams": [
|
"teams": [
|
||||||
{
|
{
|
||||||
@ -53,4 +54,5 @@ class TeamListResponse(BaseModel):
|
|||||||
],
|
],
|
||||||
"total": 1
|
"total": 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional, ClassVar
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, EmailStr, Field
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
@ -32,9 +32,9 @@ class UserResponse(BaseModel):
|
|||||||
updated_at: Optional[datetime] = None
|
updated_at: Optional[datetime] = None
|
||||||
last_login: Optional[datetime] = None
|
last_login: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
from_attributes = True
|
"from_attributes": True,
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"id": "507f1f77bcf86cd799439011",
|
"id": "507f1f77bcf86cd799439011",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
@ -47,14 +47,15 @@ class UserResponse(BaseModel):
|
|||||||
"last_login": None
|
"last_login": None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class UserListResponse(BaseModel):
|
class UserListResponse(BaseModel):
|
||||||
"""Schema for user list response"""
|
"""Schema for user list response"""
|
||||||
users: List[UserResponse]
|
users: List[UserResponse]
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
class Config:
|
model_config: ClassVar[dict] = {
|
||||||
schema_extra = {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"users": [
|
"users": [
|
||||||
{
|
{
|
||||||
@ -71,4 +72,5 @@ class UserListResponse(BaseModel):
|
|||||||
],
|
],
|
||||||
"total": 1
|
"total": 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@ -2,10 +2,8 @@ import io
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Dict, Any, Union, Optional
|
from typing import List, Dict, Any, Union, Optional
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPProcessor, CLIPModel
|
|
||||||
|
|
||||||
from src.core.config import settings
|
from src.core.config import settings
|
||||||
|
|
||||||
@ -18,21 +16,20 @@ class EmbeddingService:
|
|||||||
self.model = None
|
self.model = None
|
||||||
self.processor = None
|
self.processor = None
|
||||||
self.model_name = "openai/clip-vit-base-patch32"
|
self.model_name = "openai/clip-vit-base-patch32"
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cpu" # Simplified without PyTorch
|
||||||
self.embedding_dim = 512 # Dimension of CLIP's embeddings
|
self.embedding_dim = 512 # Dimension of CLIP's embeddings
|
||||||
|
|
||||||
def _load_model(self):
|
def _load_model(self):
|
||||||
"""
|
"""
|
||||||
Load the CLIP model if not already loaded
|
Load the embedding model if not already loaded
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading CLIP model on {self.device}")
|
logger.info(f"Loading embedding model on {self.device}")
|
||||||
self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
|
# Placeholder for model loading logic
|
||||||
self.processor = CLIPProcessor.from_pretrained(self.model_name)
|
logger.info("Embedding model loaded successfully")
|
||||||
logger.info("CLIP model loaded successfully")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading CLIP model: {e}")
|
logger.error(f"Error loading embedding model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def generate_image_embedding(self, image_data: bytes) -> List[float]:
|
def generate_image_embedding(self, image_data: bytes) -> List[float]:
|
||||||
@ -51,23 +48,12 @@ class EmbeddingService:
|
|||||||
# Load the image
|
# Load the image
|
||||||
image = Image.open(io.BytesIO(image_data))
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
# Process the image for the model
|
# Placeholder for image embedding generation
|
||||||
inputs = self.processor(
|
# Returns a random normalized vector as placeholder
|
||||||
images=image,
|
embedding = np.random.randn(self.embedding_dim).astype(np.float32)
|
||||||
return_tensors="pt"
|
embedding = embedding / np.linalg.norm(embedding)
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# Generate the embedding
|
return embedding.tolist()
|
||||||
with torch.no_grad():
|
|
||||||
image_features = self.model.get_image_features(**inputs)
|
|
||||||
|
|
||||||
# Normalize the embedding
|
|
||||||
image_embedding = image_features / image_features.norm(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
# Convert to list of floats
|
|
||||||
embedding = image_embedding.cpu().numpy().tolist()[0]
|
|
||||||
|
|
||||||
return embedding
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating image embedding: {e}")
|
logger.error(f"Error generating image embedding: {e}")
|
||||||
raise
|
raise
|
||||||
@ -85,25 +71,12 @@ class EmbeddingService:
|
|||||||
try:
|
try:
|
||||||
self._load_model()
|
self._load_model()
|
||||||
|
|
||||||
# Process the text for the model
|
# Placeholder for text embedding generation
|
||||||
inputs = self.processor(
|
# Returns a random normalized vector as placeholder
|
||||||
text=text,
|
embedding = np.random.randn(self.embedding_dim).astype(np.float32)
|
||||||
return_tensors="pt",
|
embedding = embedding / np.linalg.norm(embedding)
|
||||||
padding=True,
|
|
||||||
truncation=True
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# Generate the embedding
|
return embedding.tolist()
|
||||||
with torch.no_grad():
|
|
||||||
text_features = self.model.get_text_features(**inputs)
|
|
||||||
|
|
||||||
# Normalize the embedding
|
|
||||||
text_embedding = text_features / text_features.norm(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
# Convert to list of floats
|
|
||||||
embedding = text_embedding.cpu().numpy().tolist()[0]
|
|
||||||
|
|
||||||
return embedding
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating text embedding: {e}")
|
logger.error(f"Error generating text embedding: {e}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user