init
This commit is contained in:
parent
fd47e008a8
commit
d6dea7ef17
51
.gitignore
vendored
Normal file
51
.gitignore
vendored
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual Environment
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Local development
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
27
Dockerfile
Normal file
27
Dockerfile
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
FROM python:3.9-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy requirements file
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Expose the port the app runs on
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
# Command to run the application
|
||||||
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
156
README.md
156
README.md
@ -0,0 +1,156 @@
|
|||||||
|
# SEREACT - Secure Image Management API
|
||||||
|
|
||||||
|
SEREACT is a secure API for storing, organizing, and retrieving images with advanced search capabilities.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Secure image storage in Google Cloud Storage
|
||||||
|
- Team-based organization and access control
|
||||||
|
- API key authentication
|
||||||
|
- Semantic search using image embeddings
|
||||||
|
- Metadata extraction and storage
|
||||||
|
- Image processing capabilities
|
||||||
|
- Multi-team support
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
sereact/
|
||||||
|
├── images/ # Sample images for testing
|
||||||
|
├── sereact/ # Main application code
|
||||||
|
│ ├── deployment/ # Deployment configurations
|
||||||
|
│ │ └── cloud-run/ # Google Cloud Run configuration
|
||||||
|
│ ├── docs/ # Documentation
|
||||||
|
│ │ └── api/ # API documentation
|
||||||
|
│ ├── scripts/ # Utility scripts
|
||||||
|
│ ├── src/ # Source code
|
||||||
|
│ │ ├── api/ # API endpoints
|
||||||
|
│ │ │ └── v1/ # API version 1
|
||||||
|
│ │ ├── core/ # Core modules
|
||||||
|
│ │ ├── db/ # Database models and repositories
|
||||||
|
│ │ │ ├── models/ # Data models
|
||||||
|
│ │ │ └── repositories/ # Database operations
|
||||||
|
│ │ ├── schemas/ # API schemas (request/response)
|
||||||
|
│ │ └── services/ # Business logic services
|
||||||
|
│ └── tests/ # Test code
|
||||||
|
│ ├── api/ # API tests
|
||||||
|
│ ├── db/ # Database tests
|
||||||
|
│ └── services/ # Service tests
|
||||||
|
├── main.py # Application entry point
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
└── README.md # This file
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technology Stack
|
||||||
|
|
||||||
|
- FastAPI - Web framework
|
||||||
|
- MongoDB - Database
|
||||||
|
- Google Cloud Storage - Image storage
|
||||||
|
- Pinecone - Vector database for semantic search
|
||||||
|
- CLIP - Image embedding model
|
||||||
|
- PyTorch - Deep learning framework
|
||||||
|
- Pydantic - Data validation
|
||||||
|
|
||||||
|
## Setup and Installation
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.8+
|
||||||
|
- MongoDB
|
||||||
|
- Google Cloud account with Storage enabled
|
||||||
|
- (Optional) Pinecone account for semantic search
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
1. Clone the repository:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yourusername/sereact.git
|
||||||
|
cd sereact
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create and activate a virtual environment:
|
||||||
|
```bash
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate # Linux/macOS
|
||||||
|
venv\Scripts\activate # Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Install dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Create a `.env` file with the following environment variables:
|
||||||
|
```
|
||||||
|
# MongoDB
|
||||||
|
DATABASE_URI=mongodb://localhost:27017
|
||||||
|
DATABASE_NAME=imagedb
|
||||||
|
|
||||||
|
# Google Cloud Storage
|
||||||
|
GCS_BUCKET_NAME=your-bucket-name
|
||||||
|
GCS_CREDENTIALS_FILE=path/to/credentials.json
|
||||||
|
|
||||||
|
# Security
|
||||||
|
API_KEY_SECRET=your-secret-key
|
||||||
|
|
||||||
|
# Vector database (optional)
|
||||||
|
VECTOR_DB_API_KEY=your-pinecone-api-key
|
||||||
|
VECTOR_DB_ENVIRONMENT=your-pinecone-environment
|
||||||
|
VECTOR_DB_INDEX_NAME=image-embeddings
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Run the application:
|
||||||
|
```bash
|
||||||
|
uvicorn main:app --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Visit `http://localhost:8000/docs` in your browser to access the API documentation.
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
The API provides the following main endpoints:
|
||||||
|
|
||||||
|
- `/api/v1/auth/*` - Authentication and API key management
|
||||||
|
- `/api/v1/teams/*` - Team management
|
||||||
|
- `/api/v1/users/*` - User management
|
||||||
|
- `/api/v1/images/*` - Image upload, download, and management
|
||||||
|
- `/api/v1/search/*` - Image search functionality
|
||||||
|
|
||||||
|
Refer to the Swagger UI documentation at `/docs` for detailed endpoint information.
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Creating a New API Version
|
||||||
|
|
||||||
|
1. Create a new package under `src/api/` (e.g., `v2`)
|
||||||
|
2. Implement new endpoints
|
||||||
|
3. Update the main.py file to include the new routers
|
||||||
|
|
||||||
|
## Deployment
|
||||||
|
|
||||||
|
### Google Cloud Run
|
||||||
|
|
||||||
|
1. Build the Docker image:
|
||||||
|
```bash
|
||||||
|
docker build -t gcr.io/your-project/sereact .
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Push to Google Container Registry:
|
||||||
|
```bash
|
||||||
|
docker push gcr.io/your-project/sereact
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Deploy to Cloud Run:
|
||||||
|
```bash
|
||||||
|
gcloud run deploy sereact --image gcr.io/your-project/sereact --platform managed
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
||||||
41
deployment/cloud-run/service.yaml
Normal file
41
deployment/cloud-run/service.yaml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
apiVersion: serving.knative.dev/v1
|
||||||
|
kind: Service
|
||||||
|
metadata:
|
||||||
|
name: sereact
|
||||||
|
spec:
|
||||||
|
template:
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- image: gcr.io/your-project/sereact:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8000
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpu: "1"
|
||||||
|
memory: "1Gi"
|
||||||
|
env:
|
||||||
|
- name: DATABASE_URI
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: sereact-db-uri
|
||||||
|
key: latest
|
||||||
|
- name: DATABASE_NAME
|
||||||
|
value: "imagedb"
|
||||||
|
- name: GCS_BUCKET_NAME
|
||||||
|
value: "your-bucket-name"
|
||||||
|
- name: API_KEY_SECRET
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: sereact-api-key-secret
|
||||||
|
key: latest
|
||||||
|
- name: VECTOR_DB_API_KEY
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: sereact-vector-db-key
|
||||||
|
key: latest
|
||||||
|
- name: VECTOR_DB_ENVIRONMENT
|
||||||
|
value: "your-pinecone-env"
|
||||||
|
- name: VECTOR_DB_INDEX_NAME
|
||||||
|
value: "image-embeddings"
|
||||||
|
- name: LOG_LEVEL
|
||||||
|
value: "INFO"
|
||||||
104
main.py
Normal file
104
main.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.openapi.docs import get_swagger_ui_html
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
|
# Import API routers
|
||||||
|
from src.api.v1 import teams, users, images, auth, search
|
||||||
|
|
||||||
|
# Import configuration
|
||||||
|
from src.core.config import settings
|
||||||
|
from src.core.logging import setup_logging
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
setup_logging()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create FastAPI app
|
||||||
|
app = FastAPI(
|
||||||
|
title=settings.PROJECT_NAME,
|
||||||
|
description="API for securely storing, organizing, and retrieving images",
|
||||||
|
version="1.0.0",
|
||||||
|
docs_url=None,
|
||||||
|
redoc_url=None,
|
||||||
|
openapi_url="/api/v1/openapi.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Include API routers
|
||||||
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
|
app.include_router(teams.router, prefix="/api/v1")
|
||||||
|
app.include_router(users.router, prefix="/api/v1")
|
||||||
|
app.include_router(images.router, prefix="/api/v1")
|
||||||
|
app.include_router(search.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
# Custom exception handler
|
||||||
|
@app.exception_handler(Exception)
|
||||||
|
async def general_exception_handler(request: Request, exc: Exception):
|
||||||
|
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"detail": "Internal server error"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Custom Swagger UI with API key authentication
|
||||||
|
@app.get("/docs", include_in_schema=False)
|
||||||
|
async def custom_swagger_ui_html():
|
||||||
|
return get_swagger_ui_html(
|
||||||
|
openapi_url=app.openapi_url,
|
||||||
|
title=f"{app.title} - Swagger UI",
|
||||||
|
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||||
|
swagger_js_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
|
||||||
|
swagger_css_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Custom OpenAPI schema to include API key auth
|
||||||
|
def custom_openapi():
|
||||||
|
if app.openapi_schema:
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
openapi_schema = get_openapi(
|
||||||
|
title=app.title,
|
||||||
|
version=app.version,
|
||||||
|
description=app.description,
|
||||||
|
routes=app.routes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add API key security scheme
|
||||||
|
openapi_schema["components"] = {
|
||||||
|
"securitySchemes": {
|
||||||
|
"ApiKeyAuth": {
|
||||||
|
"type": "apiKey",
|
||||||
|
"in": "header",
|
||||||
|
"name": "X-API-Key"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply security to all endpoints except auth endpoints
|
||||||
|
for path in openapi_schema["paths"]:
|
||||||
|
if not path.startswith("/api/v1/auth"):
|
||||||
|
openapi_schema["paths"][path]["security"] = [{"ApiKeyAuth": []}]
|
||||||
|
|
||||||
|
app.openapi_schema = openapi_schema
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
|
@app.get("/", include_in_schema=False)
|
||||||
|
async def root():
|
||||||
|
return {"message": "Welcome to the Image Management API. Please see /docs for API documentation."}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
21
requirements.txt
Normal file
21
requirements.txt
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
fastapi==0.104.1
|
||||||
|
uvicorn==0.23.2
|
||||||
|
pydantic==2.4.2
|
||||||
|
pydantic-settings==2.0.3
|
||||||
|
python-dotenv==1.0.0
|
||||||
|
google-cloud-storage==2.12.0
|
||||||
|
google-cloud-vision==3.4.5
|
||||||
|
pymongo==4.5.0
|
||||||
|
motor==3.3.1
|
||||||
|
python-multipart==0.0.6
|
||||||
|
python-jose==3.3.0
|
||||||
|
passlib==1.7.4
|
||||||
|
tenacity==8.2.3
|
||||||
|
pytest==7.4.3
|
||||||
|
httpx==0.25.1
|
||||||
|
pinecone-client==2.2.4
|
||||||
|
pillow==10.1.0
|
||||||
|
clip==0.2.0
|
||||||
|
torch==2.1.0
|
||||||
|
transformers==4.35.0
|
||||||
|
python-slugify==8.0.1
|
||||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/api/__init__.py
Normal file
0
src/api/__init__.py
Normal file
0
src/api/v1/__init__.py
Normal file
0
src/api/v1/__init__.py
Normal file
186
src/api/v1/auth.py
Normal file
186
src/api/v1/auth.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.repositories.api_key_repository import api_key_repository
|
||||||
|
from src.db.repositories.user_repository import user_repository
|
||||||
|
from src.db.repositories.team_repository import team_repository
|
||||||
|
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
|
||||||
|
from src.core.security import generate_api_key, verify_api_key, calculate_expiry_date, is_expired, hash_api_key
|
||||||
|
from src.db.models.api_key import ApiKeyModel
|
||||||
|
from src.core.logging import log_request
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Authentication"], prefix="/auth")
|
||||||
|
|
||||||
|
async def get_current_user(x_api_key: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
Get the current user from API key
|
||||||
|
"""
|
||||||
|
if not x_api_key:
|
||||||
|
raise HTTPException(status_code=401, detail="API key is required")
|
||||||
|
|
||||||
|
# Hash the API key
|
||||||
|
hashed_key = hash_api_key(x_api_key)
|
||||||
|
|
||||||
|
# Get the key from the database
|
||||||
|
api_key = await api_key_repository.get_by_hash(hashed_key)
|
||||||
|
if not api_key:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
|
|
||||||
|
# Check if the key is active
|
||||||
|
if not api_key.is_active:
|
||||||
|
raise HTTPException(status_code=401, detail="API key is inactive")
|
||||||
|
|
||||||
|
# Check if the key has expired
|
||||||
|
if api_key.expiry_date and is_expired(api_key.expiry_date):
|
||||||
|
raise HTTPException(status_code=401, detail="API key has expired")
|
||||||
|
|
||||||
|
# Get the user
|
||||||
|
user = await user_repository.get_by_id(api_key.user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=401, detail="User not found")
|
||||||
|
|
||||||
|
# Check if the user is active
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(status_code=401, detail="User is inactive")
|
||||||
|
|
||||||
|
# Update last used timestamp
|
||||||
|
await api_key_repository.update_last_used(api_key.id)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
@router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201)
|
||||||
|
async def create_api_key(key_data: ApiKeyCreate, request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Create a new API key
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method, "key_data": key_data.dict()},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user's team exists
|
||||||
|
team = await team_repository.get_by_id(current_user.team_id)
|
||||||
|
if not team:
|
||||||
|
raise HTTPException(status_code=404, detail="Team not found")
|
||||||
|
|
||||||
|
# Generate API key with expiry date
|
||||||
|
raw_key, hashed_key = generate_api_key(str(current_user.team_id), str(current_user.id))
|
||||||
|
expiry_date = calculate_expiry_date()
|
||||||
|
|
||||||
|
# Create API key in database
|
||||||
|
api_key = ApiKeyModel(
|
||||||
|
key_hash=hashed_key,
|
||||||
|
user_id=current_user.id,
|
||||||
|
team_id=current_user.team_id,
|
||||||
|
name=key_data.name,
|
||||||
|
description=key_data.description,
|
||||||
|
expiry_date=expiry_date,
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
|
||||||
|
created_key = await api_key_repository.create(api_key)
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
response = ApiKeyWithValueResponse(
|
||||||
|
id=str(created_key.id),
|
||||||
|
key=raw_key,
|
||||||
|
name=created_key.name,
|
||||||
|
description=created_key.description,
|
||||||
|
team_id=str(created_key.team_id),
|
||||||
|
user_id=str(created_key.user_id),
|
||||||
|
created_at=created_key.created_at,
|
||||||
|
expiry_date=created_key.expiry_date,
|
||||||
|
last_used=created_key.last_used,
|
||||||
|
is_active=created_key.is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@router.get("/api-keys", response_model=ApiKeyListResponse)
|
||||||
|
async def list_api_keys(request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
List API keys for the current user
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get API keys for user
|
||||||
|
keys = await api_key_repository.get_by_user(current_user.id)
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
response_keys = []
|
||||||
|
for key in keys:
|
||||||
|
response_keys.append(ApiKeyResponse(
|
||||||
|
id=str(key.id),
|
||||||
|
name=key.name,
|
||||||
|
description=key.description,
|
||||||
|
team_id=str(key.team_id),
|
||||||
|
user_id=str(key.user_id),
|
||||||
|
created_at=key.created_at,
|
||||||
|
expiry_date=key.expiry_date,
|
||||||
|
last_used=key.last_used,
|
||||||
|
is_active=key.is_active
|
||||||
|
))
|
||||||
|
|
||||||
|
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
||||||
|
|
||||||
|
@router.delete("/api-keys/{key_id}", status_code=204)
|
||||||
|
async def revoke_api_key(key_id: str, request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Revoke (deactivate) an API key
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method, "key_id": key_id},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert string ID to ObjectId
|
||||||
|
obj_id = ObjectId(key_id)
|
||||||
|
except:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid key ID")
|
||||||
|
|
||||||
|
# Get the API key
|
||||||
|
key = await api_key_repository.get_by_id(obj_id)
|
||||||
|
if not key:
|
||||||
|
raise HTTPException(status_code=404, detail="API key not found")
|
||||||
|
|
||||||
|
# Check if user owns the key or is an admin
|
||||||
|
if key.user_id != current_user.id and not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized to revoke this API key")
|
||||||
|
|
||||||
|
# Deactivate the key
|
||||||
|
result = await api_key_repository.deactivate(obj_id)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to revoke API key")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@router.get("/verify", status_code=200)
|
||||||
|
async def verify_authentication(request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Verify the current authentication (API key)
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": str(current_user.id),
|
||||||
|
"name": current_user.name,
|
||||||
|
"email": current_user.email,
|
||||||
|
"team_id": str(current_user.team_id),
|
||||||
|
"is_admin": current_user.is_admin
|
||||||
|
}
|
||||||
215
src/api/v1/teams.py
Normal file
215
src/api/v1/teams.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.repositories.team_repository import team_repository
|
||||||
|
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
|
||||||
|
from src.db.models.team import TeamModel
|
||||||
|
from src.api.v1.auth import get_current_user
|
||||||
|
from src.core.logging import log_request
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Teams"], prefix="/teams")
|
||||||
|
|
||||||
|
@router.post("", response_model=TeamResponse, status_code=201)
|
||||||
|
async def create_team(team_data: TeamCreate, request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Create a new team
|
||||||
|
|
||||||
|
This endpoint requires admin privileges
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method, "team_data": team_data.dict()},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only admins can create teams
|
||||||
|
if not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Only admins can create teams")
|
||||||
|
|
||||||
|
# Create team
|
||||||
|
team = TeamModel(
|
||||||
|
name=team_data.name,
|
||||||
|
description=team_data.description
|
||||||
|
)
|
||||||
|
|
||||||
|
created_team = await team_repository.create(team)
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
response = TeamResponse(
|
||||||
|
id=str(created_team.id),
|
||||||
|
name=created_team.name,
|
||||||
|
description=created_team.description,
|
||||||
|
created_at=created_team.created_at,
|
||||||
|
updated_at=created_team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@router.get("", response_model=TeamListResponse)
|
||||||
|
async def list_teams(request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
List all teams
|
||||||
|
|
||||||
|
This endpoint requires admin privileges
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only admins can list all teams
|
||||||
|
if not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Only admins can list all teams")
|
||||||
|
|
||||||
|
# Get all teams
|
||||||
|
teams = await team_repository.get_all()
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
response_teams = []
|
||||||
|
for team in teams:
|
||||||
|
response_teams.append(TeamResponse(
|
||||||
|
id=str(team.id),
|
||||||
|
name=team.name,
|
||||||
|
description=team.description,
|
||||||
|
created_at=team.created_at,
|
||||||
|
updated_at=team.updated_at
|
||||||
|
))
|
||||||
|
|
||||||
|
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
||||||
|
|
||||||
|
@router.get("/{team_id}", response_model=TeamResponse)
|
||||||
|
async def get_team(team_id: str, request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Get a team by ID
|
||||||
|
|
||||||
|
Users can only access their own team unless they are an admin
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method, "team_id": team_id},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert string ID to ObjectId
|
||||||
|
obj_id = ObjectId(team_id)
|
||||||
|
except:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||||
|
|
||||||
|
# Get the team
|
||||||
|
team = await team_repository.get_by_id(obj_id)
|
||||||
|
if not team:
|
||||||
|
raise HTTPException(status_code=404, detail="Team not found")
|
||||||
|
|
||||||
|
# Check if user can access this team
|
||||||
|
if str(team.id) != str(current_user.team_id) and not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized to access this team")
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
response = TeamResponse(
|
||||||
|
id=str(team.id),
|
||||||
|
name=team.name,
|
||||||
|
description=team.description,
|
||||||
|
created_at=team.created_at,
|
||||||
|
updated_at=team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@router.put("/{team_id}", response_model=TeamResponse)
|
||||||
|
async def update_team(team_id: str, team_data: TeamUpdate, request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Update a team
|
||||||
|
|
||||||
|
This endpoint requires admin privileges
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method, "team_id": team_id, "team_data": team_data.dict()},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only admins can update teams
|
||||||
|
if not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Only admins can update teams")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert string ID to ObjectId
|
||||||
|
obj_id = ObjectId(team_id)
|
||||||
|
except:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||||
|
|
||||||
|
# Get the team
|
||||||
|
team = await team_repository.get_by_id(obj_id)
|
||||||
|
if not team:
|
||||||
|
raise HTTPException(status_code=404, detail="Team not found")
|
||||||
|
|
||||||
|
# Update the team
|
||||||
|
update_data = team_data.dict(exclude_unset=True)
|
||||||
|
if not update_data:
|
||||||
|
# No fields to update
|
||||||
|
return TeamResponse(
|
||||||
|
id=str(team.id),
|
||||||
|
name=team.name,
|
||||||
|
description=team.description,
|
||||||
|
created_at=team.created_at,
|
||||||
|
updated_at=team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_team = await team_repository.update(obj_id, update_data)
|
||||||
|
if not updated_team:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update team")
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
response = TeamResponse(
|
||||||
|
id=str(updated_team.id),
|
||||||
|
name=updated_team.name,
|
||||||
|
description=updated_team.description,
|
||||||
|
created_at=updated_team.created_at,
|
||||||
|
updated_at=updated_team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
@router.delete("/{team_id}", status_code=204)
|
||||||
|
async def delete_team(team_id: str, request: Request, current_user = Depends(get_current_user)):
|
||||||
|
"""
|
||||||
|
Delete a team
|
||||||
|
|
||||||
|
This endpoint requires admin privileges
|
||||||
|
"""
|
||||||
|
log_request(
|
||||||
|
{"path": request.url.path, "method": request.method, "team_id": team_id},
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
team_id=str(current_user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only admins can delete teams
|
||||||
|
if not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Only admins can delete teams")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert string ID to ObjectId
|
||||||
|
obj_id = ObjectId(team_id)
|
||||||
|
except:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||||
|
|
||||||
|
# Check if team exists
|
||||||
|
team = await team_repository.get_by_id(obj_id)
|
||||||
|
if not team:
|
||||||
|
raise HTTPException(status_code=404, detail="Team not found")
|
||||||
|
|
||||||
|
# Don't allow deleting a user's own team
|
||||||
|
if str(team.id) == str(current_user.team_id):
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot delete your own team")
|
||||||
|
|
||||||
|
# Delete the team
|
||||||
|
result = await team_repository.delete(obj_id)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete team")
|
||||||
|
|
||||||
|
return None
|
||||||
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
47
src/core/config.py
Normal file
47
src/core/config.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from pydantic import AnyHttpUrl, validator
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Project settings
|
||||||
|
PROJECT_NAME: str = "Image Management API"
|
||||||
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
|
# CORS settings
|
||||||
|
CORS_ORIGINS: List[str] = ["*"]
|
||||||
|
|
||||||
|
@validator("CORS_ORIGINS", pre=True)
|
||||||
|
def assemble_cors_origins(cls, v):
|
||||||
|
if isinstance(v, str) and not v.startswith("["):
|
||||||
|
return [i.strip() for i in v.split(",")]
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Database settings
|
||||||
|
DATABASE_URI: str = os.getenv("DATABASE_URI", "mongodb://localhost:27017")
|
||||||
|
DATABASE_NAME: str = os.getenv("DATABASE_NAME", "imagedb")
|
||||||
|
|
||||||
|
# Google Cloud Storage settings
|
||||||
|
GCS_BUCKET_NAME: str = os.getenv("GCS_BUCKET_NAME", "image-mgmt-bucket")
|
||||||
|
GCS_CREDENTIALS_FILE: str = os.getenv("GCS_CREDENTIALS_FILE", "credentials.json")
|
||||||
|
|
||||||
|
# Security settings
|
||||||
|
API_KEY_SECRET: str = os.getenv("API_KEY_SECRET", "super-secret-key-for-development-only")
|
||||||
|
API_KEY_EXPIRY_DAYS: int = int(os.getenv("API_KEY_EXPIRY_DAYS", "365"))
|
||||||
|
|
||||||
|
# Vector Database settings (for image embeddings)
|
||||||
|
VECTOR_DB_API_KEY: str = os.getenv("VECTOR_DB_API_KEY", "")
|
||||||
|
VECTOR_DB_ENVIRONMENT: str = os.getenv("VECTOR_DB_ENVIRONMENT", "")
|
||||||
|
VECTOR_DB_INDEX_NAME: str = os.getenv("VECTOR_DB_INDEX_NAME", "image-embeddings")
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "100"))
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
case_sensitive = True
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
81
src/core/logging.py
Normal file
81
src/core/logging.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
def setup_logging():
|
||||||
|
"""Configure logging settings for the application"""
|
||||||
|
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
log_level = getattr(logging, settings.LOG_LEVEL.upper())
|
||||||
|
|
||||||
|
# Configure root logger
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format=log_format,
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set log levels for specific modules
|
||||||
|
logging.getLogger("uvicorn").setLevel(logging.INFO)
|
||||||
|
logging.getLogger("uvicorn.access").setLevel(logging.INFO)
|
||||||
|
logging.getLogger("fastapi").setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# Reduce noise from third-party libraries
|
||||||
|
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("botocore").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("google").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
# Create a logger for this module
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.debug("Logging configured successfully")
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
def log_request(request_data: Dict[str, Any], user_id: str = None, team_id: str = None):
|
||||||
|
"""
|
||||||
|
Log API request data with user and team context
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: Dictionary with request details
|
||||||
|
user_id: Optional user ID
|
||||||
|
team_id: Optional team ID
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger("api.request")
|
||||||
|
|
||||||
|
log_data = {
|
||||||
|
"request": request_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
log_data["user_id"] = user_id
|
||||||
|
|
||||||
|
if team_id:
|
||||||
|
log_data["team_id"] = team_id
|
||||||
|
|
||||||
|
logger.info(f"API Request: {log_data}")
|
||||||
|
|
||||||
|
def log_error(error_message: str, error: Exception = None, context: Dict[str, Any] = None):
|
||||||
|
"""
|
||||||
|
Log error with context information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_message: Human-readable error message
|
||||||
|
error: Optional exception object
|
||||||
|
context: Optional dictionary with context data
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger("api.error")
|
||||||
|
|
||||||
|
log_data = {
|
||||||
|
"message": error_message,
|
||||||
|
}
|
||||||
|
|
||||||
|
if context:
|
||||||
|
log_data["context"] = context
|
||||||
|
|
||||||
|
if error:
|
||||||
|
logger.error(f"Error: {log_data}", exc_info=error)
|
||||||
|
else:
|
||||||
|
logger.error(f"Error: {log_data}")
|
||||||
91
src/core/security.py
Normal file
91
src/core/security.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
import time
|
||||||
|
import hmac
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
def generate_api_key(team_id: str, user_id: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Generate a secure API key and its hashed value
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID for which the key is generated
|
||||||
|
user_id: User ID for which the key is generated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (raw_api_key, hashed_api_key)
|
||||||
|
"""
|
||||||
|
# Generate a random key prefix (visible part)
|
||||||
|
prefix = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
|
||||||
|
|
||||||
|
# Generate a secure random token for the key
|
||||||
|
random_part = secrets.token_hex(16)
|
||||||
|
|
||||||
|
# Format: prefix.random_part
|
||||||
|
raw_api_key = f"{prefix}.{random_part}"
|
||||||
|
|
||||||
|
# Hash the API key for storage
|
||||||
|
hashed_api_key = hash_api_key(raw_api_key)
|
||||||
|
|
||||||
|
return raw_api_key, hashed_api_key
|
||||||
|
|
||||||
|
def hash_api_key(api_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Create a secure hash of the API key for storage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: The raw API key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hashed API key
|
||||||
|
"""
|
||||||
|
return hmac.new(
|
||||||
|
settings.API_KEY_SECRET.encode(),
|
||||||
|
api_key.encode(),
|
||||||
|
hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
def verify_api_key(api_key: str, hashed_api_key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verify if the provided API key matches the stored hash
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: The API key to verify
|
||||||
|
hashed_api_key: The stored hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the API key is valid
|
||||||
|
"""
|
||||||
|
calculated_hash = hash_api_key(api_key)
|
||||||
|
return hmac.compare_digest(calculated_hash, hashed_api_key)
|
||||||
|
|
||||||
|
def calculate_expiry_date(days: Optional[int] = None) -> datetime:
|
||||||
|
"""
|
||||||
|
Calculate the expiry date for API keys
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Optional number of days until expiry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expiry date
|
||||||
|
"""
|
||||||
|
if days is None:
|
||||||
|
days = settings.API_KEY_EXPIRY_DAYS
|
||||||
|
|
||||||
|
return datetime.utcnow() + timedelta(days=days)
|
||||||
|
|
||||||
|
def is_expired(expiry_date: datetime) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an API key has expired
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expiry_date: The expiry date
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if expired
|
||||||
|
"""
|
||||||
|
return datetime.utcnow() > expiry_date
|
||||||
33
src/db/__init__.py
Normal file
33
src/db/__init__.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import logging
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
client: AsyncIOMotorClient = None
|
||||||
|
|
||||||
|
def connect_to_database(self):
|
||||||
|
"""Create database connection."""
|
||||||
|
try:
|
||||||
|
self.client = AsyncIOMotorClient(settings.DATABASE_URI)
|
||||||
|
logger.info("Connected to MongoDB")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to MongoDB: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def close_database_connection(self):
|
||||||
|
"""Close database connection."""
|
||||||
|
try:
|
||||||
|
if self.client:
|
||||||
|
self.client.close()
|
||||||
|
logger.info("Closed MongoDB connection")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to close MongoDB connection: {e}")
|
||||||
|
|
||||||
|
def get_database(self):
|
||||||
|
"""Get the database instance."""
|
||||||
|
return self.client[settings.DATABASE_NAME]
|
||||||
|
|
||||||
|
# Create a singleton database instance
|
||||||
|
db = Database()
|
||||||
0
src/db/models/__init__.py
Normal file
0
src/db/models/__init__.py
Normal file
26
src/db/models/api_key.py
Normal file
26
src/db/models/api_key.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.models.team import PyObjectId
|
||||||
|
|
||||||
|
class ApiKeyModel(BaseModel):
|
||||||
|
"""Database model for an API key"""
|
||||||
|
id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
|
||||||
|
key_hash: str
|
||||||
|
user_id: PyObjectId
|
||||||
|
team_id: PyObjectId
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
expiry_date: Optional[datetime] = None
|
||||||
|
last_used: Optional[datetime] = None
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
json_encoders = {
|
||||||
|
ObjectId: str
|
||||||
|
}
|
||||||
35
src/db/models/image.py
Normal file
35
src/db/models/image.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.models.team import PyObjectId
|
||||||
|
|
||||||
|
class ImageModel(BaseModel):
|
||||||
|
"""Database model for an image"""
|
||||||
|
id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
|
||||||
|
filename: str
|
||||||
|
original_filename: str
|
||||||
|
file_size: int
|
||||||
|
content_type: str
|
||||||
|
storage_path: str
|
||||||
|
public_url: Optional[HttpUrl] = None
|
||||||
|
team_id: PyObjectId
|
||||||
|
uploader_id: PyObjectId
|
||||||
|
upload_date: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
last_accessed: Optional[datetime] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
tags: List[str] = []
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Fields for image understanding and semantic search
|
||||||
|
embedding_id: Optional[str] = None
|
||||||
|
embedding_model: Optional[str] = None
|
||||||
|
has_embedding: bool = False
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
json_encoders = {
|
||||||
|
ObjectId: str
|
||||||
|
}
|
||||||
34
src/db/models/team.py
Normal file
34
src/db/models/team.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
class PyObjectId(ObjectId):
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls.validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, v):
|
||||||
|
if not ObjectId.is_valid(v):
|
||||||
|
raise ValueError('Invalid ObjectId')
|
||||||
|
return ObjectId(v)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __modify_schema__(cls, field_schema):
|
||||||
|
field_schema.update(type='string')
|
||||||
|
|
||||||
|
class TeamModel(BaseModel):
|
||||||
|
"""Database model for a team"""
|
||||||
|
id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
json_encoders = {
|
||||||
|
ObjectId: str
|
||||||
|
}
|
||||||
25
src/db/models/user.py
Normal file
25
src/db/models/user.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field, EmailStr
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.models.team import PyObjectId
|
||||||
|
|
||||||
|
class UserModel(BaseModel):
|
||||||
|
"""Database model for a user"""
|
||||||
|
id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
|
||||||
|
email: EmailStr
|
||||||
|
name: str
|
||||||
|
team_id: PyObjectId
|
||||||
|
is_active: bool = True
|
||||||
|
is_admin: bool = False
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
last_login: Optional[datetime] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
json_encoders = {
|
||||||
|
ObjectId: str
|
||||||
|
}
|
||||||
0
src/db/repositories/__init__.py
Normal file
0
src/db/repositories/__init__.py
Normal file
171
src/db/repositories/api_key_repository.py
Normal file
171
src/db/repositories/api_key_repository.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db import db
|
||||||
|
from src.db.models.api_key import ApiKeyModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ApiKeyRepository:
|
||||||
|
"""Repository for API key operations"""
|
||||||
|
|
||||||
|
collection_name = "api_keys"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collection(self):
|
||||||
|
return db.get_database()[self.collection_name]
|
||||||
|
|
||||||
|
async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
|
||||||
|
"""
|
||||||
|
Create a new API key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created API key with ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.insert_one(api_key.dict(by_alias=True))
|
||||||
|
created_key = await self.get_by_id(result.inserted_id)
|
||||||
|
logger.info(f"API key created: {result.inserted_id}")
|
||||||
|
return created_key
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating API key: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id(self, key_id: ObjectId) -> Optional[ApiKeyModel]:
|
||||||
|
"""
|
||||||
|
Get API key by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_id: API key ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = await self.collection.find_one({"_id": key_id})
|
||||||
|
if key:
|
||||||
|
return ApiKeyModel(**key)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API key by ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
|
||||||
|
"""
|
||||||
|
Get API key by hash
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_hash: API key hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = await self.collection.find_one({"key_hash": key_hash})
|
||||||
|
if key:
|
||||||
|
return ApiKeyModel(**key)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API key by hash: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]:
|
||||||
|
"""
|
||||||
|
Get API keys by user ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of API keys for the user
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
keys = []
|
||||||
|
cursor = self.collection.find({"user_id": user_id})
|
||||||
|
async for document in cursor:
|
||||||
|
keys.append(ApiKeyModel(**document))
|
||||||
|
return keys
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API keys by user: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_team(self, team_id: ObjectId) -> List[ApiKeyModel]:
|
||||||
|
"""
|
||||||
|
Get API keys by team ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of API keys for the team
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
keys = []
|
||||||
|
cursor = self.collection.find({"team_id": team_id})
|
||||||
|
async for document in cursor:
|
||||||
|
keys.append(ApiKeyModel(**document))
|
||||||
|
return keys
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API keys by team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_last_used(self, key_id: ObjectId) -> None:
|
||||||
|
"""
|
||||||
|
Update API key's last used timestamp
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_id: API key ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": key_id},
|
||||||
|
{"$set": {"last_used": datetime.utcnow()}}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating API key last used: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def deactivate(self, key_id: ObjectId) -> bool:
|
||||||
|
"""
|
||||||
|
Deactivate API key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_id: API key ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deactivated, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.update_one(
|
||||||
|
{"_id": key_id},
|
||||||
|
{"$set": {"is_active": False}}
|
||||||
|
)
|
||||||
|
return result.modified_count > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deactivating API key: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, key_id: ObjectId) -> bool:
|
||||||
|
"""
|
||||||
|
Delete API key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_id: API key ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.delete_one({"_id": key_id})
|
||||||
|
return result.deleted_count > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting API key: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
api_key_repository = ApiKeyRepository()
|
||||||
239
src/db/repositories/image_repository.py
Normal file
239
src/db/repositories/image_repository.py
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db import db
|
||||||
|
from src.db.models.image import ImageModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ImageRepository:
|
||||||
|
"""Repository for image operations"""
|
||||||
|
|
||||||
|
collection_name = "images"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collection(self):
|
||||||
|
return db.get_database()[self.collection_name]
|
||||||
|
|
||||||
|
async def create(self, image: ImageModel) -> ImageModel:
|
||||||
|
"""
|
||||||
|
Create a new image record
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created image with ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.insert_one(image.dict(by_alias=True))
|
||||||
|
created_image = await self.get_by_id(result.inserted_id)
|
||||||
|
logger.info(f"Image created: {result.inserted_id}")
|
||||||
|
return created_image
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating image: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id(self, image_id: ObjectId) -> Optional[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get image by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
image = await self.collection.find_one({"_id": image_id})
|
||||||
|
if image:
|
||||||
|
return ImageModel(**image)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting image by ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_team(self, team_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get images by team ID with pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
limit: Max number of results
|
||||||
|
skip: Number of records to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of images for the team
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
images = []
|
||||||
|
cursor = self.collection.find({"team_id": team_id}).sort("upload_date", -1).skip(skip).limit(limit)
|
||||||
|
async for document in cursor:
|
||||||
|
images.append(ImageModel(**document))
|
||||||
|
return images
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting images by team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def count_by_team(self, team_id: ObjectId) -> int:
|
||||||
|
"""
|
||||||
|
Count images by team ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of images for the team
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await self.collection.count_documents({"team_id": team_id})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error counting images by team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_uploader(self, uploader_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get images by uploader ID with pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uploader_id: Uploader user ID
|
||||||
|
limit: Max number of results
|
||||||
|
skip: Number of records to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of images uploaded by the user
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
images = []
|
||||||
|
cursor = self.collection.find({"uploader_id": uploader_id}).sort("upload_date", -1).skip(skip).limit(limit)
|
||||||
|
async for document in cursor:
|
||||||
|
images.append(ImageModel(**document))
|
||||||
|
return images
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting images by uploader: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def search_by_metadata(self, team_id: ObjectId, query: Dict[str, Any], limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
||||||
|
"""
|
||||||
|
Search images by metadata
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
query: Search query
|
||||||
|
limit: Max number of results
|
||||||
|
skip: Number of records to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching images
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Ensure we only search within the team's images
|
||||||
|
search_query = {"team_id": team_id, **query}
|
||||||
|
|
||||||
|
images = []
|
||||||
|
cursor = self.collection.find(search_query).sort("upload_date", -1).skip(skip).limit(limit)
|
||||||
|
async for document in cursor:
|
||||||
|
images.append(ImageModel(**document))
|
||||||
|
return images
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error searching images by metadata: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update(self, image_id: ObjectId, image_data: dict) -> Optional[ImageModel]:
|
||||||
|
"""
|
||||||
|
Update image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
image_data: Update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated image if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Don't allow updating _id
|
||||||
|
if "_id" in image_data:
|
||||||
|
del image_data["_id"]
|
||||||
|
|
||||||
|
result = await self.collection.update_one(
|
||||||
|
{"_id": image_id},
|
||||||
|
{"$set": image_data}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.modified_count == 0:
|
||||||
|
logger.warning(f"No image updated for ID: {image_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self.get_by_id(image_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating image: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_last_accessed(self, image_id: ObjectId) -> None:
|
||||||
|
"""
|
||||||
|
Update image's last accessed timestamp
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": image_id},
|
||||||
|
{"$set": {"last_accessed": datetime.utcnow()}}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating image last accessed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, image_id: ObjectId) -> bool:
|
||||||
|
"""
|
||||||
|
Delete image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if image was deleted, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.delete_one({"_id": image_id})
|
||||||
|
return result.deleted_count > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting image: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_embedding_status(self, image_id: ObjectId, embedding_id: str, model: str) -> Optional[ImageModel]:
|
||||||
|
"""
|
||||||
|
Update image embedding status
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
embedding_id: Vector DB embedding ID
|
||||||
|
model: Model used for embedding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated image if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.update_one(
|
||||||
|
{"_id": image_id},
|
||||||
|
{"$set": {
|
||||||
|
"embedding_id": embedding_id,
|
||||||
|
"embedding_model": model,
|
||||||
|
"has_embedding": True
|
||||||
|
}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.modified_count == 0:
|
||||||
|
logger.warning(f"No image updated for embedding ID: {image_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self.get_by_id(image_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating image embedding status: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
image_repository = ImageRepository()
|
||||||
126
src/db/repositories/team_repository.py
Normal file
126
src/db/repositories/team_repository.py
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db import db
|
||||||
|
from src.db.models.team import TeamModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class TeamRepository:
|
||||||
|
"""Repository for team operations"""
|
||||||
|
|
||||||
|
collection_name = "teams"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collection(self):
|
||||||
|
return db.get_database()[self.collection_name]
|
||||||
|
|
||||||
|
async def create(self, team: TeamModel) -> TeamModel:
|
||||||
|
"""
|
||||||
|
Create a new team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team: Team data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created team with ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.insert_one(team.dict(by_alias=True))
|
||||||
|
created_team = await self.get_by_id(result.inserted_id)
|
||||||
|
logger.info(f"Team created: {result.inserted_id}")
|
||||||
|
return created_team
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id(self, team_id: ObjectId) -> Optional[TeamModel]:
|
||||||
|
"""
|
||||||
|
Get team by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Team if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
team = await self.collection.find_one({"_id": team_id})
|
||||||
|
if team:
|
||||||
|
return TeamModel(**team)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting team by ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_all(self) -> List[TeamModel]:
|
||||||
|
"""
|
||||||
|
Get all teams
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of teams
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
teams = []
|
||||||
|
cursor = self.collection.find()
|
||||||
|
async for document in cursor:
|
||||||
|
teams.append(TeamModel(**document))
|
||||||
|
return teams
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting all teams: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update(self, team_id: ObjectId, team_data: dict) -> Optional[TeamModel]:
|
||||||
|
"""
|
||||||
|
Update team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
team_data: Update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated team if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Add updated_at timestamp
|
||||||
|
team_data["updated_at"] = datetime.utcnow()
|
||||||
|
|
||||||
|
# Don't allow updating _id
|
||||||
|
if "_id" in team_data:
|
||||||
|
del team_data["_id"]
|
||||||
|
|
||||||
|
result = await self.collection.update_one(
|
||||||
|
{"_id": team_id},
|
||||||
|
{"$set": team_data}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.modified_count == 0:
|
||||||
|
logger.warning(f"No team updated for ID: {team_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self.get_by_id(team_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, team_id: ObjectId) -> bool:
|
||||||
|
"""
|
||||||
|
Delete team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if team was deleted, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.delete_one({"_id": team_id})
|
||||||
|
return result.deleted_count > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
team_repository = TeamRepository()
|
||||||
164
src/db/repositories/user_repository.py
Normal file
164
src/db/repositories/user_repository.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db import db
|
||||||
|
from src.db.models.user import UserModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UserRepository:
|
||||||
|
"""Repository for user operations"""
|
||||||
|
|
||||||
|
collection_name = "users"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def collection(self):
|
||||||
|
return db.get_database()[self.collection_name]
|
||||||
|
|
||||||
|
async def create(self, user: UserModel) -> UserModel:
|
||||||
|
"""
|
||||||
|
Create a new user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created user with ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.insert_one(user.dict(by_alias=True))
|
||||||
|
created_user = await self.get_by_id(result.inserted_id)
|
||||||
|
logger.info(f"User created: {result.inserted_id}")
|
||||||
|
return created_user
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating user: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id(self, user_id: ObjectId) -> Optional[UserModel]:
|
||||||
|
"""
|
||||||
|
Get user by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = await self.collection.find_one({"_id": user_id})
|
||||||
|
if user:
|
||||||
|
return UserModel(**user)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user by ID: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_email(self, email: str) -> Optional[UserModel]:
|
||||||
|
"""
|
||||||
|
Get user by email
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: User email
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = await self.collection.find_one({"email": email})
|
||||||
|
if user:
|
||||||
|
return UserModel(**user)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user by email: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
|
||||||
|
"""
|
||||||
|
Get users by team ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of users in the team
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
users = []
|
||||||
|
cursor = self.collection.find({"team_id": team_id})
|
||||||
|
async for document in cursor:
|
||||||
|
users.append(UserModel(**document))
|
||||||
|
return users
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting users by team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update(self, user_id: ObjectId, user_data: dict) -> Optional[UserModel]:
|
||||||
|
"""
|
||||||
|
Update user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
user_data: Update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated user if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Add updated_at timestamp
|
||||||
|
user_data["updated_at"] = datetime.utcnow()
|
||||||
|
|
||||||
|
# Don't allow updating _id
|
||||||
|
if "_id" in user_data:
|
||||||
|
del user_data["_id"]
|
||||||
|
|
||||||
|
result = await self.collection.update_one(
|
||||||
|
{"_id": user_id},
|
||||||
|
{"$set": user_data}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.modified_count == 0:
|
||||||
|
logger.warning(f"No user updated for ID: {user_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self.get_by_id(user_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating user: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, user_id: ObjectId) -> bool:
|
||||||
|
"""
|
||||||
|
Delete user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user was deleted, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await self.collection.delete_one({"_id": user_id})
|
||||||
|
return result.deleted_count > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting user: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_last_login(self, user_id: ObjectId) -> None:
|
||||||
|
"""
|
||||||
|
Update user's last login timestamp
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": user_id},
|
||||||
|
{"$set": {"last_login": datetime.utcnow()}}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating user last login: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton repository
|
||||||
|
user_repository = UserRepository()
|
||||||
0
src/schemas/__init__.py
Normal file
0
src/schemas/__init__.py
Normal file
89
src/schemas/api_key.py
Normal file
89
src/schemas/api_key.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class ApiKeyBase(BaseModel):
|
||||||
|
"""Base schema for API key data"""
|
||||||
|
name: str = Field(..., description="API key name", min_length=1, max_length=100)
|
||||||
|
description: Optional[str] = Field(None, description="API key description", max_length=500)
|
||||||
|
|
||||||
|
class ApiKeyCreate(ApiKeyBase):
|
||||||
|
"""Schema for creating an API key"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ApiKeyUpdate(BaseModel):
|
||||||
|
"""Schema for updating an API key"""
|
||||||
|
name: Optional[str] = Field(None, description="API key name", min_length=1, max_length=100)
|
||||||
|
description: Optional[str] = Field(None, description="API key description", max_length=500)
|
||||||
|
is_active: Optional[bool] = Field(None, description="Whether the API key is active")
|
||||||
|
|
||||||
|
class ApiKeyResponse(ApiKeyBase):
|
||||||
|
"""Schema for API key response"""
|
||||||
|
id: str
|
||||||
|
team_id: str
|
||||||
|
user_id: str
|
||||||
|
created_at: datetime
|
||||||
|
expiry_date: Optional[datetime] = None
|
||||||
|
last_used: Optional[datetime] = None
|
||||||
|
is_active: bool
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"name": "Development API Key",
|
||||||
|
"description": "Used for development purposes",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"user_id": "507f1f77bcf86cd799439033",
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"expiry_date": "2024-10-20T10:00:00",
|
||||||
|
"last_used": None,
|
||||||
|
"is_active": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ApiKeyWithValueResponse(ApiKeyResponse):
|
||||||
|
"""Schema for API key response with the raw value"""
|
||||||
|
key: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"name": "Development API Key",
|
||||||
|
"description": "Used for development purposes",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"user_id": "507f1f77bcf86cd799439033",
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"expiry_date": "2024-10-20T10:00:00",
|
||||||
|
"last_used": None,
|
||||||
|
"is_active": True,
|
||||||
|
"key": "abc123.xyzabc123def456"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ApiKeyListResponse(BaseModel):
|
||||||
|
"""Schema for API key list response"""
|
||||||
|
api_keys: List[ApiKeyResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"api_keys": [
|
||||||
|
{
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"name": "Development API Key",
|
||||||
|
"description": "Used for development purposes",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"user_id": "507f1f77bcf86cd799439033",
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"expiry_date": "2024-10-20T10:00:00",
|
||||||
|
"last_used": None,
|
||||||
|
"is_active": True
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
183
src/schemas/image.py
Normal file
183
src/schemas/image.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
|
class ImageBase(BaseModel):
|
||||||
|
"""Base schema for image data"""
|
||||||
|
description: Optional[str] = Field(None, description="Image description", max_length=500)
|
||||||
|
tags: List[str] = Field(default=[], description="Image tags")
|
||||||
|
|
||||||
|
class ImageUpload(ImageBase):
|
||||||
|
"""Schema for uploading an image"""
|
||||||
|
# Note: The file itself is handled by FastAPI's UploadFile
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ImageUpdate(BaseModel):
|
||||||
|
"""Schema for updating an image"""
|
||||||
|
description: Optional[str] = Field(None, description="Image description", max_length=500)
|
||||||
|
tags: Optional[List[str]] = Field(None, description="Image tags")
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Image metadata")
|
||||||
|
|
||||||
|
class ImageResponse(ImageBase):
|
||||||
|
"""Schema for image response"""
|
||||||
|
id: str
|
||||||
|
filename: str
|
||||||
|
original_filename: str
|
||||||
|
file_size: int
|
||||||
|
content_type: str
|
||||||
|
public_url: Optional[HttpUrl] = None
|
||||||
|
team_id: str
|
||||||
|
uploader_id: str
|
||||||
|
upload_date: datetime
|
||||||
|
last_accessed: Optional[datetime] = None
|
||||||
|
metadata: Dict[str, Any] = Field(default={})
|
||||||
|
has_embedding: bool = False
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"filename": "1234567890abcdef.jpg",
|
||||||
|
"original_filename": "sunset.jpg",
|
||||||
|
"file_size": 1024000,
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"uploader_id": "507f1f77bcf86cd799439033",
|
||||||
|
"upload_date": "2023-10-20T10:00:00",
|
||||||
|
"last_accessed": "2023-10-21T10:00:00",
|
||||||
|
"description": "Beautiful sunset over the mountains",
|
||||||
|
"tags": ["sunset", "mountains", "nature"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1920,
|
||||||
|
"height": 1080,
|
||||||
|
"location": "Rocky Mountains"
|
||||||
|
},
|
||||||
|
"has_embedding": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ImageListResponse(BaseModel):
|
||||||
|
"""Schema for image list response"""
|
||||||
|
images: List[ImageResponse]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
total_pages: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"filename": "1234567890abcdef.jpg",
|
||||||
|
"original_filename": "sunset.jpg",
|
||||||
|
"file_size": 1024000,
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"uploader_id": "507f1f77bcf86cd799439033",
|
||||||
|
"upload_date": "2023-10-20T10:00:00",
|
||||||
|
"last_accessed": "2023-10-21T10:00:00",
|
||||||
|
"description": "Beautiful sunset over the mountains",
|
||||||
|
"tags": ["sunset", "mountains", "nature"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1920,
|
||||||
|
"height": 1080,
|
||||||
|
"location": "Rocky Mountains"
|
||||||
|
},
|
||||||
|
"has_embedding": True
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1,
|
||||||
|
"page": 1,
|
||||||
|
"page_size": 10,
|
||||||
|
"total_pages": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ImageSearchQuery(BaseModel):
|
||||||
|
"""Schema for image search query"""
|
||||||
|
query: str = Field(..., description="Search query", min_length=1)
|
||||||
|
limit: int = Field(10, description="Maximum number of results", ge=1, le=100)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"query": "mountain sunset",
|
||||||
|
"limit": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ImageSearchResult(BaseModel):
|
||||||
|
"""Schema for image search result"""
|
||||||
|
image: ImageResponse
|
||||||
|
score: float
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"image": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"filename": "1234567890abcdef.jpg",
|
||||||
|
"original_filename": "sunset.jpg",
|
||||||
|
"file_size": 1024000,
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"uploader_id": "507f1f77bcf86cd799439033",
|
||||||
|
"upload_date": "2023-10-20T10:00:00",
|
||||||
|
"last_accessed": "2023-10-21T10:00:00",
|
||||||
|
"description": "Beautiful sunset over the mountains",
|
||||||
|
"tags": ["sunset", "mountains", "nature"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1920,
|
||||||
|
"height": 1080,
|
||||||
|
"location": "Rocky Mountains"
|
||||||
|
},
|
||||||
|
"has_embedding": True
|
||||||
|
},
|
||||||
|
"score": 0.95
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ImageSearchResponse(BaseModel):
|
||||||
|
"""Schema for image search response"""
|
||||||
|
results: List[ImageSearchResult]
|
||||||
|
total: int
|
||||||
|
query: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"filename": "1234567890abcdef.jpg",
|
||||||
|
"original_filename": "sunset.jpg",
|
||||||
|
"file_size": 1024000,
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"uploader_id": "507f1f77bcf86cd799439033",
|
||||||
|
"upload_date": "2023-10-20T10:00:00",
|
||||||
|
"last_accessed": "2023-10-21T10:00:00",
|
||||||
|
"description": "Beautiful sunset over the mountains",
|
||||||
|
"tags": ["sunset", "mountains", "nature"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1920,
|
||||||
|
"height": 1080,
|
||||||
|
"location": "Rocky Mountains"
|
||||||
|
},
|
||||||
|
"has_embedding": True
|
||||||
|
},
|
||||||
|
"score": 0.95
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1,
|
||||||
|
"query": "mountain sunset"
|
||||||
|
}
|
||||||
|
}
|
||||||
56
src/schemas/team.py
Normal file
56
src/schemas/team.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class TeamBase(BaseModel):
|
||||||
|
"""Base schema for team data"""
|
||||||
|
name: str = Field(..., description="Team name", min_length=1, max_length=100)
|
||||||
|
description: Optional[str] = Field(None, description="Team description", max_length=500)
|
||||||
|
|
||||||
|
class TeamCreate(TeamBase):
|
||||||
|
"""Schema for creating a team"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TeamUpdate(BaseModel):
|
||||||
|
"""Schema for updating a team"""
|
||||||
|
name: Optional[str] = Field(None, description="Team name", min_length=1, max_length=100)
|
||||||
|
description: Optional[str] = Field(None, description="Team description", max_length=500)
|
||||||
|
|
||||||
|
class TeamResponse(TeamBase):
|
||||||
|
"""Schema for team response"""
|
||||||
|
id: str
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"name": "Marketing Team",
|
||||||
|
"description": "Team responsible for marketing campaigns",
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"updated_at": None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class TeamListResponse(BaseModel):
|
||||||
|
"""Schema for team list response"""
|
||||||
|
teams: List[TeamResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"teams": [
|
||||||
|
{
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"name": "Marketing Team",
|
||||||
|
"description": "Team responsible for marketing campaigns",
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"updated_at": None
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
74
src/schemas/user.py
Normal file
74
src/schemas/user.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
"""Base schema for user data"""
|
||||||
|
email: EmailStr = Field(..., description="User email")
|
||||||
|
name: str = Field(..., description="User name", min_length=1, max_length=100)
|
||||||
|
team_id: str = Field(..., description="Team ID")
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
"""Schema for creating a user"""
|
||||||
|
is_admin: bool = Field(False, description="Whether the user is an admin")
|
||||||
|
|
||||||
|
class UserUpdate(BaseModel):
|
||||||
|
"""Schema for updating a user"""
|
||||||
|
email: Optional[EmailStr] = Field(None, description="User email")
|
||||||
|
name: Optional[str] = Field(None, description="User name", min_length=1, max_length=100)
|
||||||
|
team_id: Optional[str] = Field(None, description="Team ID")
|
||||||
|
is_active: Optional[bool] = Field(None, description="Whether the user is active")
|
||||||
|
is_admin: Optional[bool] = Field(None, description="Whether the user is an admin")
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
"""Schema for user response"""
|
||||||
|
id: str
|
||||||
|
email: EmailStr
|
||||||
|
name: str
|
||||||
|
team_id: str
|
||||||
|
is_active: bool
|
||||||
|
is_admin: bool
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
last_login: Optional[datetime] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"email": "user@example.com",
|
||||||
|
"name": "John Doe",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"is_active": True,
|
||||||
|
"is_admin": False,
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"updated_at": None,
|
||||||
|
"last_login": None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class UserListResponse(BaseModel):
|
||||||
|
"""Schema for user list response"""
|
||||||
|
users: List[UserResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"users": [
|
||||||
|
{
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"email": "user@example.com",
|
||||||
|
"name": "John Doe",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"is_active": True,
|
||||||
|
"is_admin": False,
|
||||||
|
"created_at": "2023-10-20T10:00:00",
|
||||||
|
"updated_at": None,
|
||||||
|
"last_login": None
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
0
src/services/__init__.py
Normal file
0
src/services/__init__.py
Normal file
136
src/services/embedding_service.py
Normal file
136
src/services/embedding_service.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Union, Optional
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import CLIPProcessor, CLIPModel
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class EmbeddingService:
|
||||||
|
"""Service for generating image and text embeddings"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
self.model_name = "openai/clip-vit-base-patch32"
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
self.embedding_dim = 512 # Dimension of CLIP's embeddings
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""
|
||||||
|
Load the CLIP model if not already loaded
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
try:
|
||||||
|
logger.info(f"Loading CLIP model on {self.device}")
|
||||||
|
self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
|
||||||
|
self.processor = CLIPProcessor.from_pretrained(self.model_name)
|
||||||
|
logger.info("CLIP model loaded successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading CLIP model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def generate_image_embedding(self, image_data: bytes) -> List[float]:
|
||||||
|
"""
|
||||||
|
Generate embedding for an image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: Binary image data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image embedding as a list of floats
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
# Load the image
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# Process the image for the model
|
||||||
|
inputs = self.processor(
|
||||||
|
images=image,
|
||||||
|
return_tensors="pt"
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
# Generate the embedding
|
||||||
|
with torch.no_grad():
|
||||||
|
image_features = self.model.get_image_features(**inputs)
|
||||||
|
|
||||||
|
# Normalize the embedding
|
||||||
|
image_embedding = image_features / image_features.norm(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
# Convert to list of floats
|
||||||
|
embedding = image_embedding.cpu().numpy().tolist()[0]
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating image embedding: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def generate_text_embedding(self, text: str) -> List[float]:
|
||||||
|
"""
|
||||||
|
Generate embedding for a text query
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text embedding as a list of floats
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
# Process the text for the model
|
||||||
|
inputs = self.processor(
|
||||||
|
text=text,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
truncation=True
|
||||||
|
).to(self.device)
|
||||||
|
|
||||||
|
# Generate the embedding
|
||||||
|
with torch.no_grad():
|
||||||
|
text_features = self.model.get_text_features(**inputs)
|
||||||
|
|
||||||
|
# Normalize the embedding
|
||||||
|
text_embedding = text_features / text_features.norm(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
# Convert to list of floats
|
||||||
|
embedding = text_embedding.cpu().numpy().tolist()[0]
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating text embedding: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def calculate_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
|
||||||
|
"""
|
||||||
|
Calculate cosine similarity between two embeddings
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding1: First embedding
|
||||||
|
embedding2: Second embedding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cosine similarity (0-1)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Convert to numpy arrays
|
||||||
|
vec1 = np.array(embedding1)
|
||||||
|
vec2 = np.array(embedding2)
|
||||||
|
|
||||||
|
# Calculate cosine similarity
|
||||||
|
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||||
|
|
||||||
|
return float(similarity)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calculating similarity: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton service
|
||||||
|
embedding_service = EmbeddingService()
|
||||||
143
src/services/image_processor.py
Normal file
143
src/services/image_processor.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import logging
|
||||||
|
import io
|
||||||
|
from typing import Dict, Any, Tuple, Optional
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from PIL.ExifTags import TAGS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ImageProcessor:
|
||||||
|
"""Service for image processing operations"""
|
||||||
|
|
||||||
|
def extract_metadata(self, image_data: bytes) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Extract metadata from an image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: Binary image data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of metadata
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Open the image with PIL
|
||||||
|
with Image.open(io.BytesIO(image_data)) as img:
|
||||||
|
# Basic image info
|
||||||
|
metadata["width"] = img.width
|
||||||
|
metadata["height"] = img.height
|
||||||
|
metadata["format"] = img.format
|
||||||
|
metadata["mode"] = img.mode
|
||||||
|
|
||||||
|
# Try to extract EXIF data if available
|
||||||
|
if hasattr(img, '_getexif') and img._getexif():
|
||||||
|
exif = {}
|
||||||
|
for tag_id, value in img._getexif().items():
|
||||||
|
tag = TAGS.get(tag_id, tag_id)
|
||||||
|
# Skip binary data that might be in EXIF
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
value = "binary data"
|
||||||
|
exif[tag] = value
|
||||||
|
metadata["exif"] = exif
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting image metadata: {e}")
|
||||||
|
# Return at least an empty dict rather than failing
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def resize_image(self, image_data: bytes, max_width: int = 1200, max_height: int = 1200) -> Tuple[bytes, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Resize an image while maintaining aspect ratio
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: Binary image data
|
||||||
|
max_width: Maximum width
|
||||||
|
max_height: Maximum height
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (resized_image_data, metadata)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Open the image with PIL
|
||||||
|
img = Image.open(io.BytesIO(image_data))
|
||||||
|
|
||||||
|
# Get original size
|
||||||
|
original_width, original_height = img.size
|
||||||
|
|
||||||
|
# Calculate new size
|
||||||
|
width_ratio = max_width / original_width if original_width > max_width else 1
|
||||||
|
height_ratio = max_height / original_height if original_height > max_height else 1
|
||||||
|
|
||||||
|
# Use the smaller ratio to ensure image fits within max dimensions
|
||||||
|
ratio = min(width_ratio, height_ratio)
|
||||||
|
|
||||||
|
# Only resize if the image is larger than the max dimensions
|
||||||
|
if ratio < 1:
|
||||||
|
new_width = int(original_width * ratio)
|
||||||
|
new_height = int(original_height * ratio)
|
||||||
|
img = img.resize((new_width, new_height), Image.LANCZOS)
|
||||||
|
|
||||||
|
logger.info(f"Resized image from {original_width}x{original_height} to {new_width}x{new_height}")
|
||||||
|
else:
|
||||||
|
logger.info(f"No resizing needed, image size {original_width}x{original_height}")
|
||||||
|
|
||||||
|
# Save to bytes
|
||||||
|
output = io.BytesIO()
|
||||||
|
img.save(output, format=img.format)
|
||||||
|
resized_data = output.getvalue()
|
||||||
|
|
||||||
|
# Get new metadata
|
||||||
|
metadata = {
|
||||||
|
"width": img.width,
|
||||||
|
"height": img.height,
|
||||||
|
"format": img.format,
|
||||||
|
"mode": img.mode,
|
||||||
|
"resized": ratio < 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return resized_data, metadata
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error resizing image: {e}")
|
||||||
|
# Return original data on error
|
||||||
|
return image_data, {"error": str(e)}
|
||||||
|
|
||||||
|
def is_image(self, mime_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a file is an image based on MIME type
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mime_type: File MIME type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the file is an image
|
||||||
|
"""
|
||||||
|
return mime_type.startswith('image/')
|
||||||
|
|
||||||
|
def validate_image(self, image_data: bytes, mime_type: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Validate an image file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: Binary image data
|
||||||
|
mime_type: File MIME type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
"""
|
||||||
|
if not self.is_image(mime_type):
|
||||||
|
return False, "File is not an image"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try opening the image with PIL
|
||||||
|
with Image.open(io.BytesIO(image_data)) as img:
|
||||||
|
# Verify image can be read
|
||||||
|
img.verify()
|
||||||
|
return True, None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Invalid image: {e}")
|
||||||
|
return False, f"Invalid image: {str(e)}"
|
||||||
|
|
||||||
|
# Create a singleton service
|
||||||
|
image_processor = ImageProcessor()
|
||||||
227
src/services/storage.py
Normal file
227
src/services/storage.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, BinaryIO, Tuple
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from google.cloud import storage
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class StorageService:
|
||||||
|
"""Service for Google Cloud Storage operations"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.bucket_name = settings.GCS_BUCKET_NAME
|
||||||
|
self.client = self._create_storage_client()
|
||||||
|
self.bucket = self._get_or_create_bucket()
|
||||||
|
|
||||||
|
def _create_storage_client(self) -> storage.Client:
|
||||||
|
"""
|
||||||
|
Create a Google Cloud Storage client
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Google Cloud Storage client
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if credentials file exists
|
||||||
|
if os.path.exists(settings.GCS_CREDENTIALS_FILE):
|
||||||
|
# Use credentials from file
|
||||||
|
credentials = service_account.Credentials.from_service_account_file(
|
||||||
|
settings.GCS_CREDENTIALS_FILE
|
||||||
|
)
|
||||||
|
return storage.Client(credentials=credentials)
|
||||||
|
else:
|
||||||
|
# Use default credentials (useful for Cloud Run, where credentials
|
||||||
|
# are provided by the environment)
|
||||||
|
logger.info("Using default credentials for GCS")
|
||||||
|
return storage.Client()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating GCS client: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_or_create_bucket(self) -> storage.Bucket:
|
||||||
|
"""
|
||||||
|
Get or create a GCS bucket
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Google Cloud Storage bucket
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if bucket exists
|
||||||
|
bucket = self.client.bucket(self.bucket_name)
|
||||||
|
if not bucket.exists():
|
||||||
|
logger.info(f"Creating bucket: {self.bucket_name}")
|
||||||
|
bucket = self.client.create_bucket(self.bucket_name)
|
||||||
|
return bucket
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting or creating bucket: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def upload_file(self, file: UploadFile, team_id: str) -> Tuple[str, str, int, dict]:
|
||||||
|
"""
|
||||||
|
Upload a file to Google Cloud Storage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: File to upload
|
||||||
|
team_id: Team ID for organization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (storage_path, content_type, file_size, metadata)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Generate a unique filename
|
||||||
|
original_filename = file.filename
|
||||||
|
extension = os.path.splitext(original_filename)[1] if original_filename else ""
|
||||||
|
unique_filename = f"{uuid.uuid4().hex}{extension}"
|
||||||
|
|
||||||
|
# Create a storage path with team ID for organization
|
||||||
|
storage_path = f"{team_id}/{unique_filename}"
|
||||||
|
|
||||||
|
# Get file content
|
||||||
|
content = await file.read()
|
||||||
|
file_size = len(content)
|
||||||
|
|
||||||
|
# Create a blob in the bucket
|
||||||
|
blob = self.bucket.blob(storage_path)
|
||||||
|
|
||||||
|
# Set content type
|
||||||
|
content_type = file.content_type
|
||||||
|
if content_type:
|
||||||
|
blob.content_type = content_type
|
||||||
|
|
||||||
|
# Set metadata
|
||||||
|
metadata = {}
|
||||||
|
try:
|
||||||
|
# Extract image metadata if it's an image
|
||||||
|
if content_type and content_type.startswith('image/'):
|
||||||
|
# Create a temporary file to read with PIL
|
||||||
|
with Image.open(BinaryIO(content)) as img:
|
||||||
|
metadata = {
|
||||||
|
'width': img.width,
|
||||||
|
'height': img.height,
|
||||||
|
'format': img.format,
|
||||||
|
'mode': img.mode
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to extract image metadata: {e}")
|
||||||
|
|
||||||
|
# Set custom metadata
|
||||||
|
blob.metadata = {
|
||||||
|
'original_filename': original_filename,
|
||||||
|
'upload_time': datetime.utcnow().isoformat(),
|
||||||
|
'team_id': team_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Upload the file
|
||||||
|
blob.upload_from_string(content, content_type=content_type)
|
||||||
|
|
||||||
|
logger.info(f"File uploaded: {storage_path}")
|
||||||
|
|
||||||
|
# Seek back to the beginning for future reads
|
||||||
|
await file.seek(0)
|
||||||
|
|
||||||
|
return storage_path, content_type, file_size, metadata
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error uploading file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_file(self, storage_path: str) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
Get a file from Google Cloud Storage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: Storage path of the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
blob = self.bucket.blob(storage_path)
|
||||||
|
if not blob.exists():
|
||||||
|
logger.warning(f"File not found: {storage_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return blob.download_as_bytes()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def delete_file(self, storage_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a file from Google Cloud Storage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: Storage path of the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file was deleted, False if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
blob = self.bucket.blob(storage_path)
|
||||||
|
if not blob.exists():
|
||||||
|
logger.warning(f"File not found for deletion: {storage_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
blob.delete()
|
||||||
|
logger.info(f"File deleted: {storage_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting file: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def generate_public_url(self, storage_path: str) -> str:
|
||||||
|
"""
|
||||||
|
Generate a public URL for a file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: Storage path of the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Public URL
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
blob = self.bucket.blob(storage_path)
|
||||||
|
return blob.public_url
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating public URL: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def generate_signed_url(self, storage_path: str, expiration_minutes: int = 30) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Generate a signed URL for temporary access to a file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: Storage path of the file
|
||||||
|
expiration_minutes: Minutes until the URL expires
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Signed URL or None if file not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
blob = self.bucket.blob(storage_path)
|
||||||
|
if not blob.exists():
|
||||||
|
logger.warning(f"File not found for signed URL: {storage_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
expiration = datetime.utcnow() + timedelta(minutes=expiration_minutes)
|
||||||
|
|
||||||
|
url = blob.generate_signed_url(
|
||||||
|
version="v4",
|
||||||
|
expiration=expiration,
|
||||||
|
method="GET"
|
||||||
|
)
|
||||||
|
|
||||||
|
return url
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating signed URL: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a singleton service
|
||||||
|
storage_service = StorageService()
|
||||||
182
src/services/vector_store.py
Normal file
182
src/services/vector_store.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
import pinecone
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class VectorStoreService:
|
||||||
|
"""Service for managing vector embeddings in Pinecone"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.api_key = settings.VECTOR_DB_API_KEY
|
||||||
|
self.environment = settings.VECTOR_DB_ENVIRONMENT
|
||||||
|
self.index_name = settings.VECTOR_DB_INDEX_NAME
|
||||||
|
self.dimension = 512 # CLIP model embedding dimension
|
||||||
|
self.initialized = False
|
||||||
|
self.index = None
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
"""
|
||||||
|
Initialize Pinecone connection and create index if needed
|
||||||
|
"""
|
||||||
|
if self.initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.api_key or not self.environment:
|
||||||
|
logger.warning("Pinecone API key or environment not provided, vector search disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Initializing Pinecone with environment {self.environment}")
|
||||||
|
|
||||||
|
# Initialize Pinecone
|
||||||
|
pinecone.init(
|
||||||
|
api_key=self.api_key,
|
||||||
|
environment=self.environment
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
if self.index_name not in pinecone.list_indexes():
|
||||||
|
logger.info(f"Creating Pinecone index: {self.index_name}")
|
||||||
|
|
||||||
|
# Create index
|
||||||
|
pinecone.create_index(
|
||||||
|
name=self.index_name,
|
||||||
|
dimension=self.dimension,
|
||||||
|
metric="cosine"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to index
|
||||||
|
self.index = pinecone.Index(self.index_name)
|
||||||
|
self.initialized = True
|
||||||
|
logger.info("Pinecone initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error initializing Pinecone: {e}")
|
||||||
|
# Don't raise - we want to gracefully handle this and fall back to non-vector search
|
||||||
|
|
||||||
|
def store_embedding(self, image_id: str, team_id: str, embedding: List[float], metadata: Dict[str, Any] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Store an embedding in Pinecone
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
team_id: Team ID
|
||||||
|
embedding: Image embedding
|
||||||
|
metadata: Additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Vector ID if successful, None otherwise
|
||||||
|
"""
|
||||||
|
self.initialize()
|
||||||
|
|
||||||
|
if not self.initialized:
|
||||||
|
logger.warning("Pinecone not initialized, cannot store embedding")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create metadata dict
|
||||||
|
meta = {
|
||||||
|
"image_id": image_id,
|
||||||
|
"team_id": team_id
|
||||||
|
}
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
meta.update(metadata)
|
||||||
|
|
||||||
|
# Create a unique vector ID
|
||||||
|
vector_id = f"{team_id}_{image_id}"
|
||||||
|
|
||||||
|
# Upsert the vector
|
||||||
|
self.index.upsert(
|
||||||
|
vectors=[
|
||||||
|
(vector_id, embedding, meta)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Stored embedding for image {image_id}")
|
||||||
|
return vector_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error storing embedding: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def search_by_embedding(self, team_id: str, query_embedding: List[float], limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Search for similar images by embedding
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID
|
||||||
|
query_embedding: Query embedding
|
||||||
|
limit: Maximum number of results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of results with image ID and similarity score
|
||||||
|
"""
|
||||||
|
self.initialize()
|
||||||
|
|
||||||
|
if not self.initialized:
|
||||||
|
logger.warning("Pinecone not initialized, cannot search by embedding")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create filter for team_id
|
||||||
|
filter_dict = {
|
||||||
|
"team_id": {"$eq": team_id}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Query the index
|
||||||
|
results = self.index.query(
|
||||||
|
vector=query_embedding,
|
||||||
|
filter=filter_dict,
|
||||||
|
top_k=limit,
|
||||||
|
include_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format the results
|
||||||
|
formatted_results = []
|
||||||
|
for match in results.matches:
|
||||||
|
formatted_results.append({
|
||||||
|
"image_id": match.metadata["image_id"],
|
||||||
|
"score": match.score,
|
||||||
|
"metadata": match.metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
return formatted_results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error searching by embedding: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def delete_embedding(self, image_id: str, team_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an embedding from Pinecone
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: Image ID
|
||||||
|
team_id: Team ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
self.initialize()
|
||||||
|
|
||||||
|
if not self.initialized:
|
||||||
|
logger.warning("Pinecone not initialized, cannot delete embedding")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create the vector ID
|
||||||
|
vector_id = f"{team_id}_{image_id}"
|
||||||
|
|
||||||
|
# Delete the vector
|
||||||
|
self.index.delete(ids=[vector_id])
|
||||||
|
|
||||||
|
logger.info(f"Deleted embedding for image {image_id}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting embedding: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create a singleton service
|
||||||
|
vector_store_service = VectorStoreService()
|
||||||
Loading…
x
Reference in New Issue
Block a user