init

parent fd47e008a8
commit d6dea7ef17

.gitignore (vendored, new file, 51 lines)
@@ -0,0 +1,51 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
env/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Logs
*.log

# Local development
.env
.env.local
.env.*.local

# Coverage reports
htmlcov/
.coverage
.coverage.*
coverage.xml
*.cover

Dockerfile (new file, 27 lines)
@@ -0,0 +1,27 @@
FROM python:3.9-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose the port the app runs on
EXPOSE 8000

# Set environment variables
ENV PYTHONUNBUFFERED=1

# Command to run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

README.md (new file, 156 lines)
@@ -0,0 +1,156 @@
# SEREACT - Secure Image Management API

SEREACT is a secure API for storing, organizing, and retrieving images with advanced search capabilities.

## Features

- Secure image storage in Google Cloud Storage
- Team-based organization and access control
- API key authentication
- Semantic search using image embeddings
- Metadata extraction and storage
- Image processing capabilities
- Multi-team support

## Architecture

```
sereact/
├── images/                    # Sample images for testing
├── sereact/                   # Main application code
│   ├── deployment/            # Deployment configurations
│   │   └── cloud-run/         # Google Cloud Run configuration
│   ├── docs/                  # Documentation
│   │   └── api/               # API documentation
│   ├── scripts/               # Utility scripts
│   ├── src/                   # Source code
│   │   ├── api/               # API endpoints
│   │   │   └── v1/            # API version 1
│   │   ├── core/              # Core modules
│   │   ├── db/                # Database models and repositories
│   │   │   ├── models/        # Data models
│   │   │   └── repositories/  # Database operations
│   │   ├── schemas/           # API schemas (request/response)
│   │   └── services/          # Business logic services
│   └── tests/                 # Test code
│       ├── api/               # API tests
│       ├── db/                # Database tests
│       └── services/          # Service tests
├── main.py                    # Application entry point
├── requirements.txt           # Python dependencies
└── README.md                  # This file
```

## Technology Stack

- FastAPI - Web framework
- MongoDB - Database
- Google Cloud Storage - Image storage
- Pinecone - Vector database for semantic search
- CLIP - Image embedding model
- PyTorch - Deep learning framework
- Pydantic - Data validation

## Setup and Installation

### Prerequisites

- Python 3.8+
- MongoDB
- Google Cloud account with Storage enabled
- (Optional) Pinecone account for semantic search

### Installation

1. Clone the repository:
   ```bash
   git clone https://github.com/yourusername/sereact.git
   cd sereact
   ```

2. Create and activate a virtual environment:
   ```bash
   python -m venv venv
   source venv/bin/activate  # Linux/macOS
   venv\Scripts\activate     # Windows
   ```

3. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```

4. Create a `.env` file with the following environment variables:
   ```
   # MongoDB
   DATABASE_URI=mongodb://localhost:27017
   DATABASE_NAME=imagedb

   # Google Cloud Storage
   GCS_BUCKET_NAME=your-bucket-name
   GCS_CREDENTIALS_FILE=path/to/credentials.json

   # Security
   API_KEY_SECRET=your-secret-key

   # Vector database (optional)
   VECTOR_DB_API_KEY=your-pinecone-api-key
   VECTOR_DB_ENVIRONMENT=your-pinecone-environment
   VECTOR_DB_INDEX_NAME=image-embeddings
   ```

5. Run the application:
   ```bash
   uvicorn main:app --reload
   ```

6. Visit `http://localhost:8000/docs` in your browser to access the API documentation.

## API Endpoints

The API provides the following main endpoints:

- `/api/v1/auth/*` - Authentication and API key management
- `/api/v1/teams/*` - Team management
- `/api/v1/users/*` - User management
- `/api/v1/images/*` - Image upload, download, and management
- `/api/v1/search/*` - Image search functionality

Refer to the Swagger UI documentation at `/docs` for detailed endpoint information.
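
Authenticated requests send the key in the `X-API-Key` header. As a quick sketch of the flow (the key value below is a placeholder; `httpx` is already in `requirements.txt`), you can check a key against the `/api/v1/auth/verify` endpoint:

```python
import httpx

API_KEY = "abc12345.0123456789abcdef0123456789abcdef"  # placeholder raw key

# The API authenticates every request via the X-API-Key header.
response = httpx.get(
    "http://localhost:8000/api/v1/auth/verify",
    headers={"X-API-Key": API_KEY},
)
response.raise_for_status()
print(response.json())  # user_id, name, email, team_id, is_admin
```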

## Development

### Running Tests

```bash
pytest
```

### Creating a New API Version

1. Create a new package under `src/api/` (e.g., `v2`)
2. Implement new endpoints
3. Update the main.py file to include the new routers (see the sketch below)
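
A minimal sketch of step 3, assuming a hypothetical `v2` package exposes `router` objects the same way `src.api.v1` does:

```python
# main.py (excerpt) — mount the new version alongside v1
from src.api.v1 import images as images_v1
from src.api.v2 import images as images_v2  # hypothetical new package

app.include_router(images_v1.router, prefix="/api/v1")
app.include_router(images_v2.router, prefix="/api/v2")
```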

## Deployment

### Google Cloud Run

1. Build the Docker image:
   ```bash
   docker build -t gcr.io/your-project/sereact .
   ```

2. Push to Google Container Registry:
   ```bash
   docker push gcr.io/your-project/sereact
   ```

3. Deploy to Cloud Run:
   ```bash
   gcloud run deploy sereact --image gcr.io/your-project/sereact --platform managed
   ```

## License

This project is licensed under the MIT License - see the LICENSE file for details.

deployment/cloud-run/service.yaml (new file, 41 lines)
@@ -0,0 +1,41 @@
apiVersion: serving.knative.dev/v1
kind: Service
metadata:
  name: sereact
spec:
  template:
    spec:
      containers:
        - image: gcr.io/your-project/sereact:latest
          ports:
            - containerPort: 8000
          resources:
            limits:
              cpu: "1"
              memory: "1Gi"
          env:
            - name: DATABASE_URI
              valueFrom:
                secretKeyRef:
                  name: sereact-db-uri
                  key: latest
            - name: DATABASE_NAME
              value: "imagedb"
            - name: GCS_BUCKET_NAME
              value: "your-bucket-name"
            - name: API_KEY_SECRET
              valueFrom:
                secretKeyRef:
                  name: sereact-api-key-secret
                  key: latest
            - name: VECTOR_DB_API_KEY
              valueFrom:
                secretKeyRef:
                  name: sereact-vector-db-key
                  key: latest
            - name: VECTOR_DB_ENVIRONMENT
              value: "your-pinecone-env"
            - name: VECTOR_DB_INDEX_NAME
              value: "image-embeddings"
            - name: LOG_LEVEL
              value: "INFO"

main.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import logging
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi

# Import API routers
from src.api.v1 import teams, users, images, auth, search

# Import configuration and the database handle
from src.core.config import settings
from src.core.logging import setup_logging
from src.db import db

# Set up logging
setup_logging()
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title=settings.PROJECT_NAME,
    description="API for securely storing, organizing, and retrieving images",
    version="1.0.0",
    docs_url=None,
    redoc_url=None,
    openapi_url="/api/v1/openapi.json",
)


# Open and close the MongoDB connection with the application lifecycle;
# the repositories assume db.client has already been initialized.
@app.on_event("startup")
async def startup_event():
    db.connect_to_database()


@app.on_event("shutdown")
async def shutdown_event():
    db.close_database_connection()


# Set up CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include API routers
app.include_router(auth.router, prefix="/api/v1")
app.include_router(teams.router, prefix="/api/v1")
app.include_router(users.router, prefix="/api/v1")
app.include_router(images.router, prefix="/api/v1")
app.include_router(search.router, prefix="/api/v1")


# Custom exception handler
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"},
    )


# Custom Swagger UI served at /docs (the default docs_url is disabled above)
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=f"{app.title} - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_js_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css",
    )


# Custom OpenAPI schema to include API key auth
def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Add the API key security scheme without clobbering existing components
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "ApiKeyAuth": {
            "type": "apiKey",
            "in": "header",
            "name": "X-API-Key",
        }
    }

    # Apply security to all endpoints except auth endpoints
    for path in openapi_schema["paths"]:
        if not path.startswith("/api/v1/auth"):
            openapi_schema["paths"][path]["security"] = [{"ApiKeyAuth": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi


@app.get("/", include_in_schema=False)
async def root():
    return {"message": "Welcome to the Image Management API. Please see /docs for API documentation."}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

requirements.txt (new file, 21 lines)
@@ -0,0 +1,21 @@
fastapi==0.104.1
uvicorn==0.23.2
pydantic==2.4.2
pydantic-settings==2.0.3
python-dotenv==1.0.0
google-cloud-storage==2.12.0
google-cloud-vision==3.4.5
pymongo==4.5.0
motor==3.3.1
python-multipart==0.0.6
python-jose==3.3.0
passlib==1.7.4
tenacity==8.2.3
pytest==7.4.3
httpx==0.25.1
pinecone-client==2.2.4
pillow==10.1.0
clip==0.2.0
torch==2.1.0
transformers==4.35.0
python-slugify==8.0.1

src/__init__.py (new file, empty)
src/api/__init__.py (new file, empty)
src/api/v1/__init__.py (new file, empty)

src/api/v1/auth.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from bson import ObjectId

from src.db.repositories.api_key_repository import api_key_repository
from src.db.repositories.user_repository import user_repository
from src.db.repositories.team_repository import team_repository
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
from src.core.security import generate_api_key, calculate_expiry_date, is_expired, hash_api_key
from src.db.models.api_key import ApiKeyModel
from src.core.logging import log_request

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Authentication"], prefix="/auth")


async def get_current_user(x_api_key: Optional[str] = Header(None)):
    """
    Get the current user from the API key
    """
    if not x_api_key:
        raise HTTPException(status_code=401, detail="API key is required")

    # Hash the API key
    hashed_key = hash_api_key(x_api_key)

    # Get the key from the database
    api_key = await api_key_repository.get_by_hash(hashed_key)
    if not api_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Check that the key is active
    if not api_key.is_active:
        raise HTTPException(status_code=401, detail="API key is inactive")

    # Check that the key has not expired
    if api_key.expiry_date and is_expired(api_key.expiry_date):
        raise HTTPException(status_code=401, detail="API key has expired")

    # Get the user
    user = await user_repository.get_by_id(api_key.user_id)
    if not user:
        raise HTTPException(status_code=401, detail="User not found")

    # Check that the user is active
    if not user.is_active:
        raise HTTPException(status_code=401, detail="User is inactive")

    # Update the last used timestamp
    await api_key_repository.update_last_used(api_key.id)

    return user


@router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201)
async def create_api_key(key_data: ApiKeyCreate, request: Request, current_user = Depends(get_current_user)):
    """
    Create a new API key
    """
    log_request(
        {"path": request.url.path, "method": request.method, "key_data": key_data.dict()},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    # Check that the user's team exists
    team = await team_repository.get_by_id(current_user.team_id)
    if not team:
        raise HTTPException(status_code=404, detail="Team not found")

    # Generate an API key with an expiry date
    raw_key, hashed_key = generate_api_key(str(current_user.team_id), str(current_user.id))
    expiry_date = calculate_expiry_date()

    # Create the API key in the database
    api_key = ApiKeyModel(
        key_hash=hashed_key,
        user_id=current_user.id,
        team_id=current_user.team_id,
        name=key_data.name,
        description=key_data.description,
        expiry_date=expiry_date,
        is_active=True
    )

    created_key = await api_key_repository.create(api_key)

    # Convert to the response model; the raw key is only returned here, once
    response = ApiKeyWithValueResponse(
        id=str(created_key.id),
        key=raw_key,
        name=created_key.name,
        description=created_key.description,
        team_id=str(created_key.team_id),
        user_id=str(created_key.user_id),
        created_at=created_key.created_at,
        expiry_date=created_key.expiry_date,
        last_used=created_key.last_used,
        is_active=created_key.is_active
    )

    return response


@router.get("/api-keys", response_model=ApiKeyListResponse)
async def list_api_keys(request: Request, current_user = Depends(get_current_user)):
    """
    List API keys for the current user
    """
    log_request(
        {"path": request.url.path, "method": request.method},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    # Get API keys for the user
    keys = await api_key_repository.get_by_user(current_user.id)

    # Convert to response models
    response_keys = []
    for key in keys:
        response_keys.append(ApiKeyResponse(
            id=str(key.id),
            name=key.name,
            description=key.description,
            team_id=str(key.team_id),
            user_id=str(key.user_id),
            created_at=key.created_at,
            expiry_date=key.expiry_date,
            last_used=key.last_used,
            is_active=key.is_active
        ))

    return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))


@router.delete("/api-keys/{key_id}", status_code=204)
async def revoke_api_key(key_id: str, request: Request, current_user = Depends(get_current_user)):
    """
    Revoke (deactivate) an API key
    """
    log_request(
        {"path": request.url.path, "method": request.method, "key_id": key_id},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        # Convert the string ID to an ObjectId
        obj_id = ObjectId(key_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid key ID")

    # Get the API key
    key = await api_key_repository.get_by_id(obj_id)
    if not key:
        raise HTTPException(status_code=404, detail="API key not found")

    # Check that the user owns the key or is an admin
    if key.user_id != current_user.id and not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Not authorized to revoke this API key")

    # Deactivate the key
    result = await api_key_repository.deactivate(obj_id)
    if not result:
        raise HTTPException(status_code=500, detail="Failed to revoke API key")

    return None


@router.get("/verify", status_code=200)
async def verify_authentication(request: Request, current_user = Depends(get_current_user)):
    """
    Verify the current authentication (API key)
    """
    log_request(
        {"path": request.url.path, "method": request.method},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    return {
        "user_id": str(current_user.id),
        "name": current_user.name,
        "email": current_user.email,
        "team_id": str(current_user.team_id),
        "is_admin": current_user.is_admin
    }

src/api/v1/teams.py (new file, 215 lines)
@@ -0,0 +1,215 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from bson import ObjectId

from src.db.repositories.team_repository import team_repository
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
from src.db.models.team import TeamModel
from src.api.v1.auth import get_current_user
from src.core.logging import log_request

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Teams"], prefix="/teams")


@router.post("", response_model=TeamResponse, status_code=201)
async def create_team(team_data: TeamCreate, request: Request, current_user = Depends(get_current_user)):
    """
    Create a new team

    This endpoint requires admin privileges
    """
    log_request(
        {"path": request.url.path, "method": request.method, "team_data": team_data.dict()},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    # Only admins can create teams
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Only admins can create teams")

    # Create the team
    team = TeamModel(
        name=team_data.name,
        description=team_data.description
    )

    created_team = await team_repository.create(team)

    # Convert to the response model
    response = TeamResponse(
        id=str(created_team.id),
        name=created_team.name,
        description=created_team.description,
        created_at=created_team.created_at,
        updated_at=created_team.updated_at
    )

    return response


@router.get("", response_model=TeamListResponse)
async def list_teams(request: Request, current_user = Depends(get_current_user)):
    """
    List all teams

    This endpoint requires admin privileges
    """
    log_request(
        {"path": request.url.path, "method": request.method},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    # Only admins can list all teams
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Only admins can list all teams")

    # Get all teams
    teams = await team_repository.get_all()

    # Convert to response models
    response_teams = []
    for team in teams:
        response_teams.append(TeamResponse(
            id=str(team.id),
            name=team.name,
            description=team.description,
            created_at=team.created_at,
            updated_at=team.updated_at
        ))

    return TeamListResponse(teams=response_teams, total=len(response_teams))


@router.get("/{team_id}", response_model=TeamResponse)
async def get_team(team_id: str, request: Request, current_user = Depends(get_current_user)):
    """
    Get a team by ID

    Users can only access their own team unless they are an admin
    """
    log_request(
        {"path": request.url.path, "method": request.method, "team_id": team_id},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        # Convert the string ID to an ObjectId
        obj_id = ObjectId(team_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid team ID")

    # Get the team
    team = await team_repository.get_by_id(obj_id)
    if not team:
        raise HTTPException(status_code=404, detail="Team not found")

    # Check that the user can access this team
    if str(team.id) != str(current_user.team_id) and not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Not authorized to access this team")

    # Convert to the response model
    response = TeamResponse(
        id=str(team.id),
        name=team.name,
        description=team.description,
        created_at=team.created_at,
        updated_at=team.updated_at
    )

    return response


@router.put("/{team_id}", response_model=TeamResponse)
async def update_team(team_id: str, team_data: TeamUpdate, request: Request, current_user = Depends(get_current_user)):
    """
    Update a team

    This endpoint requires admin privileges
    """
    log_request(
        {"path": request.url.path, "method": request.method, "team_id": team_id, "team_data": team_data.dict()},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    # Only admins can update teams
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Only admins can update teams")

    try:
        # Convert the string ID to an ObjectId
        obj_id = ObjectId(team_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid team ID")

    # Get the team
    team = await team_repository.get_by_id(obj_id)
    if not team:
        raise HTTPException(status_code=404, detail="Team not found")

    # Update the team
    update_data = team_data.dict(exclude_unset=True)
    if not update_data:
        # No fields to update; return the team unchanged
        return TeamResponse(
            id=str(team.id),
            name=team.name,
            description=team.description,
            created_at=team.created_at,
            updated_at=team.updated_at
        )

    updated_team = await team_repository.update(obj_id, update_data)
    if not updated_team:
        raise HTTPException(status_code=500, detail="Failed to update team")

    # Convert to the response model
    response = TeamResponse(
        id=str(updated_team.id),
        name=updated_team.name,
        description=updated_team.description,
        created_at=updated_team.created_at,
        updated_at=updated_team.updated_at
    )

    return response


@router.delete("/{team_id}", status_code=204)
async def delete_team(team_id: str, request: Request, current_user = Depends(get_current_user)):
    """
    Delete a team

    This endpoint requires admin privileges
    """
    log_request(
        {"path": request.url.path, "method": request.method, "team_id": team_id},
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    # Only admins can delete teams
    if not current_user.is_admin:
        raise HTTPException(status_code=403, detail="Only admins can delete teams")

    try:
        # Convert the string ID to an ObjectId
        obj_id = ObjectId(team_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid team ID")

    # Check that the team exists
    team = await team_repository.get_by_id(obj_id)
    if not team:
        raise HTTPException(status_code=404, detail="Team not found")

    # Don't allow deleting the user's own team
    if str(team.id) == str(current_user.team_id):
        raise HTTPException(status_code=400, detail="Cannot delete your own team")

    # Delete the team
    result = await team_repository.delete(obj_id)
    if not result:
        raise HTTPException(status_code=500, detail="Failed to delete team")

    return None

src/core/__init__.py (new file, empty)

src/core/config.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import os
from typing import List
from pydantic_settings import BaseSettings
from pydantic import validator


class Settings(BaseSettings):
    # Project settings
    PROJECT_NAME: str = "Image Management API"
    API_V1_STR: str = "/api/v1"

    # CORS settings
    CORS_ORIGINS: List[str] = ["*"]

    @validator("CORS_ORIGINS", pre=True)
    def assemble_cors_origins(cls, v):
        # Accept a comma-separated string as well as a JSON-style list
        if isinstance(v, str) and not v.startswith("["):
            return [i.strip() for i in v.split(",")]
        return v

    # Database settings
    DATABASE_URI: str = os.getenv("DATABASE_URI", "mongodb://localhost:27017")
    DATABASE_NAME: str = os.getenv("DATABASE_NAME", "imagedb")

    # Google Cloud Storage settings
    GCS_BUCKET_NAME: str = os.getenv("GCS_BUCKET_NAME", "image-mgmt-bucket")
    GCS_CREDENTIALS_FILE: str = os.getenv("GCS_CREDENTIALS_FILE", "credentials.json")

    # Security settings
    API_KEY_SECRET: str = os.getenv("API_KEY_SECRET", "super-secret-key-for-development-only")
    API_KEY_EXPIRY_DAYS: int = int(os.getenv("API_KEY_EXPIRY_DAYS", "365"))

    # Vector database settings (for image embeddings)
    VECTOR_DB_API_KEY: str = os.getenv("VECTOR_DB_API_KEY", "")
    VECTOR_DB_ENVIRONMENT: str = os.getenv("VECTOR_DB_ENVIRONMENT", "")
    VECTOR_DB_INDEX_NAME: str = os.getenv("VECTOR_DB_INDEX_NAME", "image-embeddings")

    # Rate limiting
    RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "100"))

    # Logging
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")

    class Config:
        case_sensitive = True
        env_file = ".env"


settings = Settings()

src/core/logging.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import logging
import sys
from typing import Dict, Any, Optional

from src.core.config import settings


def setup_logging():
    """Configure logging settings for the application"""
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    log_level = getattr(logging, settings.LOG_LEVEL.upper())

    # Configure the root logger
    logging.basicConfig(
        level=log_level,
        format=log_format,
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )

    # Set log levels for specific modules
    logging.getLogger("uvicorn").setLevel(logging.INFO)
    logging.getLogger("uvicorn.access").setLevel(logging.INFO)
    logging.getLogger("fastapi").setLevel(logging.INFO)

    # Reduce noise from third-party libraries
    logging.getLogger("asyncio").setLevel(logging.WARNING)
    logging.getLogger("botocore").setLevel(logging.WARNING)
    logging.getLogger("google").setLevel(logging.WARNING)

    # Create a logger for this module
    logger = logging.getLogger(__name__)
    logger.debug("Logging configured successfully")

    return logger


def log_request(request_data: Dict[str, Any], user_id: Optional[str] = None, team_id: Optional[str] = None):
    """
    Log API request data with user and team context

    Args:
        request_data: Dictionary with request details
        user_id: Optional user ID
        team_id: Optional team ID
    """
    logger = logging.getLogger("api.request")

    log_data = {
        "request": request_data,
    }

    if user_id:
        log_data["user_id"] = user_id

    if team_id:
        log_data["team_id"] = team_id

    logger.info(f"API Request: {log_data}")


def log_error(error_message: str, error: Optional[Exception] = None, context: Optional[Dict[str, Any]] = None):
    """
    Log an error with context information

    Args:
        error_message: Human-readable error message
        error: Optional exception object
        context: Optional dictionary with context data
    """
    logger = logging.getLogger("api.error")

    log_data = {
        "message": error_message,
    }

    if context:
        log_data["context"] = context

    if error:
        logger.error(f"Error: {log_data}", exc_info=error)
    else:
        logger.error(f"Error: {log_data}")

src/core/security.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import secrets
import string
import hmac
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Tuple

from src.core.config import settings


def generate_api_key(team_id: str, user_id: str) -> Tuple[str, str]:
    """
    Generate a secure API key and its hashed value

    Args:
        team_id: Team ID for which the key is generated
        user_id: User ID for which the key is generated

    Returns:
        Tuple of (raw_api_key, hashed_api_key)
    """
    # Generate a random key prefix (the visible part)
    prefix = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))

    # Generate a secure random token for the key
    random_part = secrets.token_hex(16)

    # Format: prefix.random_part
    raw_api_key = f"{prefix}.{random_part}"

    # Hash the API key for storage
    hashed_api_key = hash_api_key(raw_api_key)

    return raw_api_key, hashed_api_key


def hash_api_key(api_key: str) -> str:
    """
    Create a secure hash of the API key for storage

    Args:
        api_key: The raw API key

    Returns:
        Hashed API key
    """
    return hmac.new(
        settings.API_KEY_SECRET.encode(),
        api_key.encode(),
        hashlib.sha256
    ).hexdigest()


def verify_api_key(api_key: str, hashed_api_key: str) -> bool:
    """
    Verify that the provided API key matches the stored hash

    Args:
        api_key: The API key to verify
        hashed_api_key: The stored hash

    Returns:
        True if the API key is valid
    """
    calculated_hash = hash_api_key(api_key)
    return hmac.compare_digest(calculated_hash, hashed_api_key)


def calculate_expiry_date(days: Optional[int] = None) -> datetime:
    """
    Calculate the expiry date for API keys

    Args:
        days: Optional number of days until expiry

    Returns:
        Expiry date
    """
    if days is None:
        days = settings.API_KEY_EXPIRY_DAYS

    return datetime.utcnow() + timedelta(days=days)


def is_expired(expiry_date: datetime) -> bool:
    """
    Check whether an API key has expired

    Args:
        expiry_date: The expiry date

    Returns:
        True if expired
    """
    return datetime.utcnow() > expiry_date
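
A quick round trip through the helpers above (an illustrative sketch; it relies on `API_KEY_SECRET` being configured, for which `src.core.config` ships a development default):

```python
from src.core.security import (
    generate_api_key, verify_api_key, calculate_expiry_date, is_expired
)

# The raw key is shown to the caller exactly once; only the HMAC hash is stored.
raw_key, stored_hash = generate_api_key(team_id="team-1", user_id="user-1")

assert verify_api_key(raw_key, stored_hash)          # correct key matches
assert not verify_api_key("wrong.key", stored_hash)  # wrong key does not
assert not is_expired(calculate_expiry_date(days=30))
```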

src/db/__init__.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import logging
from typing import Optional
from motor.motor_asyncio import AsyncIOMotorClient
from src.core.config import settings

logger = logging.getLogger(__name__)


class Database:
    client: Optional[AsyncIOMotorClient] = None

    def connect_to_database(self):
        """Create the database connection."""
        try:
            self.client = AsyncIOMotorClient(settings.DATABASE_URI)
            logger.info("Connected to MongoDB")
        except Exception as e:
            logger.error(f"Failed to connect to MongoDB: {e}")
            raise

    def close_database_connection(self):
        """Close the database connection."""
        try:
            if self.client:
                self.client.close()
                logger.info("Closed MongoDB connection")
        except Exception as e:
            logger.error(f"Failed to close MongoDB connection: {e}")

    def get_database(self):
        """Get the database instance (requires connect_to_database first)."""
        return self.client[settings.DATABASE_NAME]


# Create a singleton database instance
db = Database()

src/db/models/__init__.py (new file, empty)

src/db/models/api_key.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from bson import ObjectId

from src.db.models.team import PyObjectId


class ApiKeyModel(BaseModel):
    """Database model for an API key"""
    id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
    key_hash: str
    user_id: PyObjectId
    team_id: PyObjectId
    name: str
    description: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    expiry_date: Optional[datetime] = None
    last_used: Optional[datetime] = None
    is_active: bool = True

    class Config:
        allow_population_by_field_name = True
        arbitrary_types_allowed = True
        json_encoders = {
            ObjectId: str
        }

src/db/models/image.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from datetime import datetime
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field, HttpUrl
from bson import ObjectId

from src.db.models.team import PyObjectId


class ImageModel(BaseModel):
    """Database model for an image"""
    id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
    filename: str
    original_filename: str
    file_size: int
    content_type: str
    storage_path: str
    public_url: Optional[HttpUrl] = None
    team_id: PyObjectId
    uploader_id: PyObjectId
    upload_date: datetime = Field(default_factory=datetime.utcnow)
    last_accessed: Optional[datetime] = None
    description: Optional[str] = None
    tags: List[str] = []
    metadata: Dict[str, Any] = {}

    # Fields for image understanding and semantic search
    embedding_id: Optional[str] = None
    embedding_model: Optional[str] = None
    has_embedding: bool = False

    class Config:
        allow_population_by_field_name = True
        arbitrary_types_allowed = True
        json_encoders = {
            ObjectId: str
        }

src/db/models/team.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
from bson import ObjectId


class PyObjectId(ObjectId):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        if not ObjectId.is_valid(v):
            raise ValueError('Invalid ObjectId')
        return ObjectId(v)

    @classmethod
    def __modify_schema__(cls, field_schema):
        field_schema.update(type='string')


class TeamModel(BaseModel):
    """Database model for a team"""
    id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
    name: str
    description: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: Optional[datetime] = None

    class Config:
        allow_population_by_field_name = True
        arbitrary_types_allowed = True
        json_encoders = {
            ObjectId: str
        }
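
These models follow Pydantic v1 conventions (`__get_validators__`, `allow_population_by_field_name`), and `PyObjectId` lets a model accept either an `ObjectId` or its 24-hex-character string form. An illustrative sketch under those v1 semantics:

```python
from bson import ObjectId
from src.db.models.team import TeamModel

# The "_id" alias is populated from a plain string and coerced to ObjectId
team = TeamModel(name="research", _id="507f1f77bcf86cd799439011")
assert isinstance(team.id, ObjectId)

# json_encoders renders ObjectId values back as strings
print(team.json())
```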

src/db/models/user.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field, EmailStr
from bson import ObjectId

from src.db.models.team import PyObjectId


class UserModel(BaseModel):
    """Database model for a user"""
    id: Optional[PyObjectId] = Field(default_factory=PyObjectId, alias="_id")
    email: EmailStr
    name: str
    team_id: PyObjectId
    is_active: bool = True
    is_admin: bool = False
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: Optional[datetime] = None
    last_login: Optional[datetime] = None

    class Config:
        allow_population_by_field_name = True
        arbitrary_types_allowed = True
        json_encoders = {
            ObjectId: str
        }

src/db/repositories/__init__.py (new file, empty)

src/db/repositories/api_key_repository.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId

from src.db import db
from src.db.models.api_key import ApiKeyModel

logger = logging.getLogger(__name__)


class ApiKeyRepository:
    """Repository for API key operations"""

    collection_name = "api_keys"

    @property
    def collection(self):
        return db.get_database()[self.collection_name]

    async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
        """
        Create a new API key

        Args:
            api_key: API key data

        Returns:
            Created API key with ID
        """
        try:
            result = await self.collection.insert_one(api_key.dict(by_alias=True))
            created_key = await self.get_by_id(result.inserted_id)
            logger.info(f"API key created: {result.inserted_id}")
            return created_key
        except Exception as e:
            logger.error(f"Error creating API key: {e}")
            raise

    async def get_by_id(self, key_id: ObjectId) -> Optional[ApiKeyModel]:
        """
        Get an API key by ID

        Args:
            key_id: API key ID

        Returns:
            API key if found, None otherwise
        """
        try:
            key = await self.collection.find_one({"_id": key_id})
            if key:
                return ApiKeyModel(**key)
            return None
        except Exception as e:
            logger.error(f"Error getting API key by ID: {e}")
            raise

    async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
        """
        Get an API key by hash

        Args:
            key_hash: API key hash

        Returns:
            API key if found, None otherwise
        """
        try:
            key = await self.collection.find_one({"key_hash": key_hash})
            if key:
                return ApiKeyModel(**key)
            return None
        except Exception as e:
            logger.error(f"Error getting API key by hash: {e}")
            raise

    async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]:
        """
        Get API keys by user ID

        Args:
            user_id: User ID

        Returns:
            List of API keys for the user
        """
        try:
            keys = []
            cursor = self.collection.find({"user_id": user_id})
            async for document in cursor:
                keys.append(ApiKeyModel(**document))
            return keys
        except Exception as e:
            logger.error(f"Error getting API keys by user: {e}")
            raise

    async def get_by_team(self, team_id: ObjectId) -> List[ApiKeyModel]:
        """
        Get API keys by team ID

        Args:
            team_id: Team ID

        Returns:
            List of API keys for the team
        """
        try:
            keys = []
            cursor = self.collection.find({"team_id": team_id})
            async for document in cursor:
                keys.append(ApiKeyModel(**document))
            return keys
        except Exception as e:
            logger.error(f"Error getting API keys by team: {e}")
            raise

    async def update_last_used(self, key_id: ObjectId) -> None:
        """
        Update an API key's last used timestamp

        Args:
            key_id: API key ID
        """
        try:
            await self.collection.update_one(
                {"_id": key_id},
                {"$set": {"last_used": datetime.utcnow()}}
            )
        except Exception as e:
            logger.error(f"Error updating API key last used: {e}")
            raise

    async def deactivate(self, key_id: ObjectId) -> bool:
        """
        Deactivate an API key

        Args:
            key_id: API key ID

        Returns:
            True if deactivated, False otherwise
        """
        try:
            result = await self.collection.update_one(
                {"_id": key_id},
                {"$set": {"is_active": False}}
            )
            return result.modified_count > 0
        except Exception as e:
            logger.error(f"Error deactivating API key: {e}")
            raise

    async def delete(self, key_id: ObjectId) -> bool:
        """
        Delete an API key

        Args:
            key_id: API key ID

        Returns:
            True if deleted, False otherwise
        """
        try:
            result = await self.collection.delete_one({"_id": key_id})
            return result.deleted_count > 0
        except Exception as e:
            logger.error(f"Error deleting API key: {e}")
            raise


# Create a singleton repository
api_key_repository = ApiKeyRepository()

src/db/repositories/image_repository.py (new file, 239 lines)
@@ -0,0 +1,239 @@
import logging
from datetime import datetime
from typing import List, Optional, Dict, Any
from bson import ObjectId

from src.db import db
from src.db.models.image import ImageModel

logger = logging.getLogger(__name__)


class ImageRepository:
    """Repository for image operations"""

    collection_name = "images"

    @property
    def collection(self):
        return db.get_database()[self.collection_name]

    async def create(self, image: ImageModel) -> ImageModel:
        """
        Create a new image record

        Args:
            image: Image data

        Returns:
            Created image with ID
        """
        try:
            result = await self.collection.insert_one(image.dict(by_alias=True))
            created_image = await self.get_by_id(result.inserted_id)
            logger.info(f"Image created: {result.inserted_id}")
            return created_image
        except Exception as e:
            logger.error(f"Error creating image: {e}")
            raise

    async def get_by_id(self, image_id: ObjectId) -> Optional[ImageModel]:
        """
        Get an image by ID

        Args:
            image_id: Image ID

        Returns:
            Image if found, None otherwise
        """
        try:
            image = await self.collection.find_one({"_id": image_id})
            if image:
                return ImageModel(**image)
            return None
        except Exception as e:
            logger.error(f"Error getting image by ID: {e}")
            raise

    async def get_by_team(self, team_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
        """
        Get images by team ID with pagination

        Args:
            team_id: Team ID
            limit: Max number of results
            skip: Number of records to skip

        Returns:
            List of images for the team
        """
        try:
            images = []
            cursor = self.collection.find({"team_id": team_id}).sort("upload_date", -1).skip(skip).limit(limit)
            async for document in cursor:
                images.append(ImageModel(**document))
            return images
        except Exception as e:
            logger.error(f"Error getting images by team: {e}")
            raise

    async def count_by_team(self, team_id: ObjectId) -> int:
        """
        Count images by team ID

        Args:
            team_id: Team ID

        Returns:
            Number of images for the team
        """
        try:
            return await self.collection.count_documents({"team_id": team_id})
        except Exception as e:
            logger.error(f"Error counting images by team: {e}")
            raise

    async def get_by_uploader(self, uploader_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
        """
        Get images by uploader ID with pagination

        Args:
            uploader_id: Uploader user ID
            limit: Max number of results
            skip: Number of records to skip

        Returns:
            List of images uploaded by the user
        """
        try:
            images = []
            cursor = self.collection.find({"uploader_id": uploader_id}).sort("upload_date", -1).skip(skip).limit(limit)
            async for document in cursor:
                images.append(ImageModel(**document))
            return images
        except Exception as e:
            logger.error(f"Error getting images by uploader: {e}")
            raise

    async def search_by_metadata(self, team_id: ObjectId, query: Dict[str, Any], limit: int = 100, skip: int = 0) -> List[ImageModel]:
        """
        Search images by metadata

        Args:
            team_id: Team ID
            query: Search query
            limit: Max number of results
            skip: Number of records to skip

        Returns:
            List of matching images
        """
        try:
            # Ensure we only search within the team's images
            search_query = {"team_id": team_id, **query}

            images = []
            cursor = self.collection.find(search_query).sort("upload_date", -1).skip(skip).limit(limit)
            async for document in cursor:
                images.append(ImageModel(**document))
            return images
        except Exception as e:
            logger.error(f"Error searching images by metadata: {e}")
            raise

    async def update(self, image_id: ObjectId, image_data: dict) -> Optional[ImageModel]:
        """
        Update an image

        Args:
            image_id: Image ID
            image_data: Update data

        Returns:
            Updated image if found, None otherwise
        """
        try:
            # Don't allow updating _id
            if "_id" in image_data:
                del image_data["_id"]

            result = await self.collection.update_one(
                {"_id": image_id},
                {"$set": image_data}
            )

            if result.modified_count == 0:
                logger.warning(f"No image updated for ID: {image_id}")
                return None

            return await self.get_by_id(image_id)
        except Exception as e:
            logger.error(f"Error updating image: {e}")
            raise

    async def update_last_accessed(self, image_id: ObjectId) -> None:
        """
        Update an image's last accessed timestamp

        Args:
            image_id: Image ID
        """
        try:
            await self.collection.update_one(
                {"_id": image_id},
                {"$set": {"last_accessed": datetime.utcnow()}}
            )
        except Exception as e:
            logger.error(f"Error updating image last accessed: {e}")
            raise

    async def delete(self, image_id: ObjectId) -> bool:
        """
        Delete an image

        Args:
            image_id: Image ID

        Returns:
            True if the image was deleted, False otherwise
        """
        try:
            result = await self.collection.delete_one({"_id": image_id})
            return result.deleted_count > 0
        except Exception as e:
            logger.error(f"Error deleting image: {e}")
            raise

    async def update_embedding_status(self, image_id: ObjectId, embedding_id: str, model: str) -> Optional[ImageModel]:
        """
        Update an image's embedding status

        Args:
            image_id: Image ID
            embedding_id: Vector DB embedding ID
            model: Model used for embedding

        Returns:
            Updated image if found, None otherwise
        """
        try:
            result = await self.collection.update_one(
                {"_id": image_id},
                {"$set": {
                    "embedding_id": embedding_id,
                    "embedding_model": model,
                    "has_embedding": True
                }}
            )

            if result.modified_count == 0:
                logger.warning(f"No image updated for embedding ID: {image_id}")
                return None

            return await self.get_by_id(image_id)
        except Exception as e:
            logger.error(f"Error updating image embedding status: {e}")
            raise


# Create a singleton repository
image_repository = ImageRepository()
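
Typical paging over a team's images with the repository above (an illustrative sketch; `page` is 0-based, and the calls must run inside an async context):

```python
async def list_team_page(team_id, page: int = 0, page_size: int = 25):
    # skip/limit paging; the repository already sorts by upload_date, newest first
    images = await image_repository.get_by_team(
        team_id, limit=page_size, skip=page * page_size
    )
    total = await image_repository.count_by_team(team_id)
    return images, total
```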

src/db/repositories/team_repository.py (new file, 126 lines)
@@ -0,0 +1,126 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId

from src.db import db
from src.db.models.team import TeamModel

logger = logging.getLogger(__name__)


class TeamRepository:
    """Repository for team operations"""

    collection_name = "teams"

    @property
    def collection(self):
        return db.get_database()[self.collection_name]

    async def create(self, team: TeamModel) -> TeamModel:
        """
        Create a new team

        Args:
            team: Team data

        Returns:
            Created team with ID
        """
        try:
            result = await self.collection.insert_one(team.dict(by_alias=True))
            created_team = await self.get_by_id(result.inserted_id)
            logger.info(f"Team created: {result.inserted_id}")
            return created_team
        except Exception as e:
            logger.error(f"Error creating team: {e}")
            raise

    async def get_by_id(self, team_id: ObjectId) -> Optional[TeamModel]:
        """
        Get a team by ID

        Args:
            team_id: Team ID

        Returns:
            Team if found, None otherwise
        """
        try:
            team = await self.collection.find_one({"_id": team_id})
            if team:
                return TeamModel(**team)
            return None
        except Exception as e:
            logger.error(f"Error getting team by ID: {e}")
            raise

    async def get_all(self) -> List[TeamModel]:
        """
        Get all teams

        Returns:
            List of teams
        """
        try:
            teams = []
            cursor = self.collection.find()
            async for document in cursor:
                teams.append(TeamModel(**document))
            return teams
        except Exception as e:
            logger.error(f"Error getting all teams: {e}")
            raise

    async def update(self, team_id: ObjectId, team_data: dict) -> Optional[TeamModel]:
        """
        Update a team

        Args:
            team_id: Team ID
            team_data: Update data

        Returns:
            Updated team if found, None otherwise
        """
        try:
            # Add the updated_at timestamp
            team_data["updated_at"] = datetime.utcnow()

            # Don't allow updating _id
            if "_id" in team_data:
                del team_data["_id"]

            result = await self.collection.update_one(
                {"_id": team_id},
                {"$set": team_data}
            )

            if result.modified_count == 0:
                logger.warning(f"No team updated for ID: {team_id}")
                return None

            return await self.get_by_id(team_id)
        except Exception as e:
            logger.error(f"Error updating team: {e}")
            raise

    async def delete(self, team_id: ObjectId) -> bool:
        """
        Delete a team

        Args:
            team_id: Team ID

        Returns:
            True if the team was deleted, False otherwise
        """
        try:
            result = await self.collection.delete_one({"_id": team_id})
            return result.deleted_count > 0
        except Exception as e:
            logger.error(f"Error deleting team: {e}")
            raise


# Create a singleton repository
team_repository = TeamRepository()

src/db/repositories/user_repository.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId

from src.db import db
from src.db.models.user import UserModel

logger = logging.getLogger(__name__)

class UserRepository:
    """Repository for user operations"""

    collection_name = "users"

    @property
    def collection(self):
        return db.get_database()[self.collection_name]

    async def create(self, user: UserModel) -> UserModel:
        """
        Create a new user

        Args:
            user: User data

        Returns:
            Created user with ID
        """
        try:
            result = await self.collection.insert_one(user.dict(by_alias=True))
            created_user = await self.get_by_id(result.inserted_id)
            logger.info(f"User created: {result.inserted_id}")
            return created_user
        except Exception as e:
            logger.error(f"Error creating user: {e}")
            raise

    async def get_by_id(self, user_id: ObjectId) -> Optional[UserModel]:
        """
        Get user by ID

        Args:
            user_id: User ID

        Returns:
            User if found, None otherwise
        """
        try:
            user = await self.collection.find_one({"_id": user_id})
            if user:
                return UserModel(**user)
            return None
        except Exception as e:
            logger.error(f"Error getting user by ID: {e}")
            raise

    async def get_by_email(self, email: str) -> Optional[UserModel]:
        """
        Get user by email

        Args:
            email: User email

        Returns:
            User if found, None otherwise
        """
        try:
            user = await self.collection.find_one({"email": email})
            if user:
                return UserModel(**user)
            return None
        except Exception as e:
            logger.error(f"Error getting user by email: {e}")
            raise

    async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
        """
        Get users by team ID

        Args:
            team_id: Team ID

        Returns:
            List of users in the team
        """
        try:
            users = []
            cursor = self.collection.find({"team_id": team_id})
            async for document in cursor:
                users.append(UserModel(**document))
            return users
        except Exception as e:
            logger.error(f"Error getting users by team: {e}")
            raise

    async def update(self, user_id: ObjectId, user_data: dict) -> Optional[UserModel]:
        """
        Update user

        Args:
            user_id: User ID
            user_data: Update data

        Returns:
            Updated user if found, None otherwise
        """
        try:
            # Add updated_at timestamp
            user_data["updated_at"] = datetime.utcnow()

            # Don't allow updating _id
            if "_id" in user_data:
                del user_data["_id"]

            result = await self.collection.update_one(
                {"_id": user_id},
                {"$set": user_data}
            )

            if result.modified_count == 0:
                logger.warning(f"No user updated for ID: {user_id}")
                return None

            return await self.get_by_id(user_id)
        except Exception as e:
            logger.error(f"Error updating user: {e}")
            raise

    async def delete(self, user_id: ObjectId) -> bool:
        """
        Delete user

        Args:
            user_id: User ID

        Returns:
            True if user was deleted, False otherwise
        """
        try:
            result = await self.collection.delete_one({"_id": user_id})
            return result.deleted_count > 0
        except Exception as e:
            logger.error(f"Error deleting user: {e}")
            raise

    async def update_last_login(self, user_id: ObjectId) -> None:
        """
        Update user's last login timestamp

        Args:
            user_id: User ID
        """
        try:
            await self.collection.update_one(
                {"_id": user_id},
                {"$set": {"last_login": datetime.utcnow()}}
            )
        except Exception as e:
            logger.error(f"Error updating user last login: {e}")
            raise

# Create a singleton repository
user_repository = UserRepository()
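
The same repository pattern, sketched for a hypothetical login flow (not part of this commit; assumes `src.db` is connected and that `UserModel` exposes its Mongo `_id` as `id`):

```
# Hypothetical usage sketch; the email is illustrative.
import asyncio

from src.db.repositories.user_repository import user_repository


async def record_login(email: str):
    user = await user_repository.get_by_email(email)
    if user is None:
        raise ValueError(f"unknown user: {email}")
    # Stamp last_login; assumes UserModel aliases Mongo's _id to `id`
    await user_repository.update_last_login(user.id)
    return user


asyncio.run(record_login("user@example.com"))
```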
0
src/schemas/__init__.py
Normal file
89
src/schemas/api_key.py
Normal file
@ -0,0 +1,89 @@
from typing import List, Optional
from datetime import datetime
from pydantic import BaseModel, Field

class ApiKeyBase(BaseModel):
    """Base schema for API key data"""
    name: str = Field(..., description="API key name", min_length=1, max_length=100)
    description: Optional[str] = Field(None, description="API key description", max_length=500)

class ApiKeyCreate(ApiKeyBase):
    """Schema for creating an API key"""
    pass

class ApiKeyUpdate(BaseModel):
    """Schema for updating an API key"""
    name: Optional[str] = Field(None, description="API key name", min_length=1, max_length=100)
    description: Optional[str] = Field(None, description="API key description", max_length=500)
    is_active: Optional[bool] = Field(None, description="Whether the API key is active")

class ApiKeyResponse(ApiKeyBase):
    """Schema for API key response"""
    id: str
    team_id: str
    user_id: str
    created_at: datetime
    expiry_date: Optional[datetime] = None
    last_used: Optional[datetime] = None
    is_active: bool

    class Config:
        from_attributes = True
        schema_extra = {
            "example": {
                "id": "507f1f77bcf86cd799439011",
                "name": "Development API Key",
                "description": "Used for development purposes",
                "team_id": "507f1f77bcf86cd799439022",
                "user_id": "507f1f77bcf86cd799439033",
                "created_at": "2023-10-20T10:00:00",
                "expiry_date": "2024-10-20T10:00:00",
                "last_used": None,
                "is_active": True
            }
        }

class ApiKeyWithValueResponse(ApiKeyResponse):
    """Schema for API key response with the raw value"""
    key: str

    class Config:
        schema_extra = {
            "example": {
                "id": "507f1f77bcf86cd799439011",
                "name": "Development API Key",
                "description": "Used for development purposes",
                "team_id": "507f1f77bcf86cd799439022",
                "user_id": "507f1f77bcf86cd799439033",
                "created_at": "2023-10-20T10:00:00",
                "expiry_date": "2024-10-20T10:00:00",
                "last_used": None,
                "is_active": True,
                "key": "abc123.xyzabc123def456"
            }
        }

class ApiKeyListResponse(BaseModel):
    """Schema for API key list response"""
    api_keys: List[ApiKeyResponse]
    total: int

    class Config:
        schema_extra = {
            "example": {
                "api_keys": [
                    {
                        "id": "507f1f77bcf86cd799439011",
                        "name": "Development API Key",
                        "description": "Used for development purposes",
                        "team_id": "507f1f77bcf86cd799439022",
                        "user_id": "507f1f77bcf86cd799439033",
                        "created_at": "2023-10-20T10:00:00",
                        "expiry_date": "2024-10-20T10:00:00",
                        "last_used": None,
                        "is_active": True
                    }
                ],
                "total": 1
            }
        }
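
The split between `ApiKeyResponse` and `ApiKeyWithValueResponse` implies the raw key is returned exactly once, at creation time. A sketch of that flow with made-up values (not part of this commit):

```
# Illustrative only; all IDs and the key value are fabricated.
from src.schemas.api_key import ApiKeyCreate, ApiKeyWithValueResponse

payload = ApiKeyCreate(name="Development API Key",
                       description="Used for development purposes")

# What a create endpoint might return once, before discarding the raw key
created = ApiKeyWithValueResponse(
    id="507f1f77bcf86cd799439011",
    name=payload.name,
    description=payload.description,
    team_id="507f1f77bcf86cd799439022",
    user_id="507f1f77bcf86cd799439033",
    created_at="2023-10-20T10:00:00",  # pydantic parses ISO-8601 strings
    is_active=True,
    key="abc123.xyzabc123def456",  # store only a hash server-side
)

# Subsequent reads would use ApiKeyResponse, which omits `key` entirely.
```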
183
src/schemas/image.py
Normal file
@ -0,0 +1,183 @@
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, HttpUrl

class ImageBase(BaseModel):
    """Base schema for image data"""
    description: Optional[str] = Field(None, description="Image description", max_length=500)
    tags: List[str] = Field(default=[], description="Image tags")

class ImageUpload(ImageBase):
    """Schema for uploading an image"""
    # Note: The file itself is handled by FastAPI's UploadFile
    pass

class ImageUpdate(BaseModel):
    """Schema for updating an image"""
    description: Optional[str] = Field(None, description="Image description", max_length=500)
    tags: Optional[List[str]] = Field(None, description="Image tags")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Image metadata")

class ImageResponse(ImageBase):
    """Schema for image response"""
    id: str
    filename: str
    original_filename: str
    file_size: int
    content_type: str
    public_url: Optional[HttpUrl] = None
    team_id: str
    uploader_id: str
    upload_date: datetime
    last_accessed: Optional[datetime] = None
    metadata: Dict[str, Any] = Field(default={})
    has_embedding: bool = False

    class Config:
        from_attributes = True
        schema_extra = {
            "example": {
                "id": "507f1f77bcf86cd799439011",
                "filename": "1234567890abcdef.jpg",
                "original_filename": "sunset.jpg",
                "file_size": 1024000,
                "content_type": "image/jpeg",
                "public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
                "team_id": "507f1f77bcf86cd799439022",
                "uploader_id": "507f1f77bcf86cd799439033",
                "upload_date": "2023-10-20T10:00:00",
                "last_accessed": "2023-10-21T10:00:00",
                "description": "Beautiful sunset over the mountains",
                "tags": ["sunset", "mountains", "nature"],
                "metadata": {
                    "width": 1920,
                    "height": 1080,
                    "location": "Rocky Mountains"
                },
                "has_embedding": True
            }
        }

class ImageListResponse(BaseModel):
    """Schema for image list response"""
    images: List[ImageResponse]
    total: int
    page: int
    page_size: int
    total_pages: int

    class Config:
        schema_extra = {
            "example": {
                "images": [
                    {
                        "id": "507f1f77bcf86cd799439011",
                        "filename": "1234567890abcdef.jpg",
                        "original_filename": "sunset.jpg",
                        "file_size": 1024000,
                        "content_type": "image/jpeg",
                        "public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
                        "team_id": "507f1f77bcf86cd799439022",
                        "uploader_id": "507f1f77bcf86cd799439033",
                        "upload_date": "2023-10-20T10:00:00",
                        "last_accessed": "2023-10-21T10:00:00",
                        "description": "Beautiful sunset over the mountains",
                        "tags": ["sunset", "mountains", "nature"],
                        "metadata": {
                            "width": 1920,
                            "height": 1080,
                            "location": "Rocky Mountains"
                        },
                        "has_embedding": True
                    }
                ],
                "total": 1,
                "page": 1,
                "page_size": 10,
                "total_pages": 1
            }
        }

class ImageSearchQuery(BaseModel):
    """Schema for image search query"""
    query: str = Field(..., description="Search query", min_length=1)
    limit: int = Field(10, description="Maximum number of results", ge=1, le=100)

    class Config:
        schema_extra = {
            "example": {
                "query": "mountain sunset",
                "limit": 10
            }
        }

class ImageSearchResult(BaseModel):
    """Schema for image search result"""
    image: ImageResponse
    score: float

    class Config:
        schema_extra = {
            "example": {
                "image": {
                    "id": "507f1f77bcf86cd799439011",
                    "filename": "1234567890abcdef.jpg",
                    "original_filename": "sunset.jpg",
                    "file_size": 1024000,
                    "content_type": "image/jpeg",
                    "public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
                    "team_id": "507f1f77bcf86cd799439022",
                    "uploader_id": "507f1f77bcf86cd799439033",
                    "upload_date": "2023-10-20T10:00:00",
                    "last_accessed": "2023-10-21T10:00:00",
                    "description": "Beautiful sunset over the mountains",
                    "tags": ["sunset", "mountains", "nature"],
                    "metadata": {
                        "width": 1920,
                        "height": 1080,
                        "location": "Rocky Mountains"
                    },
                    "has_embedding": True
                },
                "score": 0.95
            }
        }

class ImageSearchResponse(BaseModel):
    """Schema for image search response"""
    results: List[ImageSearchResult]
    total: int
    query: str

    class Config:
        schema_extra = {
            "example": {
                "results": [
                    {
                        "image": {
                            "id": "507f1f77bcf86cd799439011",
                            "filename": "1234567890abcdef.jpg",
                            "original_filename": "sunset.jpg",
                            "file_size": 1024000,
                            "content_type": "image/jpeg",
                            "public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
                            "team_id": "507f1f77bcf86cd799439022",
                            "uploader_id": "507f1f77bcf86cd799439033",
                            "upload_date": "2023-10-20T10:00:00",
                            "last_accessed": "2023-10-21T10:00:00",
                            "description": "Beautiful sunset over the mountains",
                            "tags": ["sunset", "mountains", "nature"],
                            "metadata": {
                                "width": 1920,
                                "height": 1080,
                                "location": "Rocky Mountains"
                            },
                            "has_embedding": True
                        },
                        "score": 0.95
                    }
                ],
                "total": 1,
                "query": "mountain sunset"
            }
        }
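
Pagination fields in `ImageListResponse` are plain integers, so `total_pages` has to be derived by the caller; ceiling division is the natural choice. A small sketch with made-up numbers (not part of this commit):

```
# Illustrative only; totals and page sizes are fabricated.
import math

from src.schemas.image import ImageListResponse, ImageSearchQuery

total, page, page_size = 37, 2, 10
listing = ImageListResponse(
    images=[],  # the current page of ImageResponse items
    total=total,
    page=page,
    page_size=page_size,
    total_pages=math.ceil(total / page_size),  # 37 items, 10 per page -> 4
)

# Field(ge=1, le=100) means out-of-range limits raise a validation error
query = ImageSearchQuery(query="mountain sunset", limit=10)
```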
56
src/schemas/team.py
Normal file
@ -0,0 +1,56 @@
from typing import List, Optional
from datetime import datetime
from pydantic import BaseModel, Field

class TeamBase(BaseModel):
    """Base schema for team data"""
    name: str = Field(..., description="Team name", min_length=1, max_length=100)
    description: Optional[str] = Field(None, description="Team description", max_length=500)

class TeamCreate(TeamBase):
    """Schema for creating a team"""
    pass

class TeamUpdate(BaseModel):
    """Schema for updating a team"""
    name: Optional[str] = Field(None, description="Team name", min_length=1, max_length=100)
    description: Optional[str] = Field(None, description="Team description", max_length=500)

class TeamResponse(TeamBase):
    """Schema for team response"""
    id: str
    created_at: datetime
    updated_at: Optional[datetime] = None

    class Config:
        from_attributes = True
        schema_extra = {
            "example": {
                "id": "507f1f77bcf86cd799439011",
                "name": "Marketing Team",
                "description": "Team responsible for marketing campaigns",
                "created_at": "2023-10-20T10:00:00",
                "updated_at": None
            }
        }

class TeamListResponse(BaseModel):
    """Schema for team list response"""
    teams: List[TeamResponse]
    total: int

    class Config:
        schema_extra = {
            "example": {
                "teams": [
                    {
                        "id": "507f1f77bcf86cd799439011",
                        "name": "Marketing Team",
                        "description": "Team responsible for marketing campaigns",
                        "created_at": "2023-10-20T10:00:00",
                        "updated_at": None
                    }
                ],
                "total": 1
            }
        }
74
src/schemas/user.py
Normal file
@ -0,0 +1,74 @@
from typing import List, Optional
from datetime import datetime
from pydantic import BaseModel, EmailStr, Field

class UserBase(BaseModel):
    """Base schema for user data"""
    email: EmailStr = Field(..., description="User email")
    name: str = Field(..., description="User name", min_length=1, max_length=100)
    team_id: str = Field(..., description="Team ID")

class UserCreate(UserBase):
    """Schema for creating a user"""
    is_admin: bool = Field(False, description="Whether the user is an admin")

class UserUpdate(BaseModel):
    """Schema for updating a user"""
    email: Optional[EmailStr] = Field(None, description="User email")
    name: Optional[str] = Field(None, description="User name", min_length=1, max_length=100)
    team_id: Optional[str] = Field(None, description="Team ID")
    is_active: Optional[bool] = Field(None, description="Whether the user is active")
    is_admin: Optional[bool] = Field(None, description="Whether the user is an admin")

class UserResponse(BaseModel):
    """Schema for user response"""
    id: str
    email: EmailStr
    name: str
    team_id: str
    is_active: bool
    is_admin: bool
    created_at: datetime
    updated_at: Optional[datetime] = None
    last_login: Optional[datetime] = None

    class Config:
        from_attributes = True
        schema_extra = {
            "example": {
                "id": "507f1f77bcf86cd799439011",
                "email": "user@example.com",
                "name": "John Doe",
                "team_id": "507f1f77bcf86cd799439022",
                "is_active": True,
                "is_admin": False,
                "created_at": "2023-10-20T10:00:00",
                "updated_at": None,
                "last_login": None
            }
        }

class UserListResponse(BaseModel):
    """Schema for user list response"""
    users: List[UserResponse]
    total: int

    class Config:
        schema_extra = {
            "example": {
                "users": [
                    {
                        "id": "507f1f77bcf86cd799439011",
                        "email": "user@example.com",
                        "name": "John Doe",
                        "team_id": "507f1f77bcf86cd799439022",
                        "is_active": True,
                        "is_admin": False,
                        "created_at": "2023-10-20T10:00:00",
                        "updated_at": None,
                        "last_login": None
                    }
                ],
                "total": 1
            }
        }
0
src/services/__init__.py
Normal file
136
src/services/embedding_service.py
Normal file
@ -0,0 +1,136 @@
import io
import logging
import os
from typing import List, Dict, Any, Union, Optional
import torch
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

from src.core.config import settings

logger = logging.getLogger(__name__)

class EmbeddingService:
    """Service for generating image and text embeddings"""

    def __init__(self):
        self.model = None
        self.processor = None
        self.model_name = "openai/clip-vit-base-patch32"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.embedding_dim = 512  # Dimension of CLIP's embeddings

    def _load_model(self):
        """
        Load the CLIP model if not already loaded
        """
        if self.model is None:
            try:
                logger.info(f"Loading CLIP model on {self.device}")
                self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
                self.processor = CLIPProcessor.from_pretrained(self.model_name)
                logger.info("CLIP model loaded successfully")
            except Exception as e:
                logger.error(f"Error loading CLIP model: {e}")
                raise

    def generate_image_embedding(self, image_data: bytes) -> List[float]:
        """
        Generate embedding for an image

        Args:
            image_data: Binary image data

        Returns:
            Image embedding as a list of floats
        """
        try:
            self._load_model()

            # Load the image
            image = Image.open(io.BytesIO(image_data))

            # Process the image for the model
            inputs = self.processor(
                images=image,
                return_tensors="pt"
            ).to(self.device)

            # Generate the embedding
            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs)

            # Normalize the embedding
            image_embedding = image_features / image_features.norm(dim=1, keepdim=True)

            # Convert to list of floats
            embedding = image_embedding.cpu().numpy().tolist()[0]

            return embedding
        except Exception as e:
            logger.error(f"Error generating image embedding: {e}")
            raise

    def generate_text_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for a text query

        Args:
            text: Text query

        Returns:
            Text embedding as a list of floats
        """
        try:
            self._load_model()

            # Process the text for the model
            inputs = self.processor(
                text=text,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.device)

            # Generate the embedding
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)

            # Normalize the embedding
            text_embedding = text_features / text_features.norm(dim=1, keepdim=True)

            # Convert to list of floats
            embedding = text_embedding.cpu().numpy().tolist()[0]

            return embedding
        except Exception as e:
            logger.error(f"Error generating text embedding: {e}")
            raise

    def calculate_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Calculate cosine similarity between two embeddings

        Args:
            embedding1: First embedding
            embedding2: Second embedding

        Returns:
            Cosine similarity (-1 to 1)
        """
        try:
            # Convert to numpy arrays
            vec1 = np.array(embedding1)
            vec2 = np.array(embedding2)

            # Calculate cosine similarity
            similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

            return float(similarity)
        except Exception as e:
            logger.error(f"Error calculating similarity: {e}")
            raise

# Create a singleton service
embedding_service = EmbeddingService()
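
Because both embedding methods L2-normalize their vectors into CLIP's shared space, text-to-image scoring reduces to a dot product. A minimal usage sketch, not part of this commit (the first call downloads model weights; the file path is illustrative):

```
# Illustrative only.
from src.services.embedding_service import embedding_service

with open("images/sunset.jpg", "rb") as f:
    image_vec = embedding_service.generate_image_embedding(f.read())

text_vec = embedding_service.generate_text_embedding("a sunset over mountains")

# Both vectors are 512-dim and unit-length, so cosine equals the dot product
score = embedding_service.calculate_similarity(image_vec, text_vec)
print(f"similarity: {score:.3f}")
```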
143
src/services/image_processor.py
Normal file
@ -0,0 +1,143 @@
import logging
import io
from typing import Dict, Any, Tuple, Optional
from PIL import Image, ImageOps
from PIL.ExifTags import TAGS

logger = logging.getLogger(__name__)

class ImageProcessor:
    """Service for image processing operations"""

    def extract_metadata(self, image_data: bytes) -> Dict[str, Any]:
        """
        Extract metadata from an image

        Args:
            image_data: Binary image data

        Returns:
            Dictionary of metadata
        """
        try:
            metadata = {}

            # Open the image with PIL
            with Image.open(io.BytesIO(image_data)) as img:
                # Basic image info
                metadata["width"] = img.width
                metadata["height"] = img.height
                metadata["format"] = img.format
                metadata["mode"] = img.mode

                # Try to extract EXIF data if available
                if hasattr(img, '_getexif') and img._getexif():
                    exif = {}
                    for tag_id, value in img._getexif().items():
                        tag = TAGS.get(tag_id, tag_id)
                        # Skip binary data that might be in EXIF
                        if isinstance(value, (bytes, bytearray)):
                            value = "binary data"
                        exif[tag] = value
                    metadata["exif"] = exif

            return metadata
        except Exception as e:
            logger.error(f"Error extracting image metadata: {e}")
            # Return at least an empty dict rather than failing
            return {}

    def resize_image(self, image_data: bytes, max_width: int = 1200, max_height: int = 1200) -> Tuple[bytes, Dict[str, Any]]:
        """
        Resize an image while maintaining aspect ratio

        Args:
            image_data: Binary image data
            max_width: Maximum width
            max_height: Maximum height

        Returns:
            Tuple of (resized_image_data, metadata)
        """
        try:
            # Open the image with PIL
            img = Image.open(io.BytesIO(image_data))

            # Remember the format now: resize() returns a new image whose format is None
            original_format = img.format

            # Get original size
            original_width, original_height = img.size

            # Calculate new size
            width_ratio = max_width / original_width if original_width > max_width else 1
            height_ratio = max_height / original_height if original_height > max_height else 1

            # Use the smaller ratio to ensure image fits within max dimensions
            ratio = min(width_ratio, height_ratio)

            # Only resize if the image is larger than the max dimensions
            if ratio < 1:
                new_width = int(original_width * ratio)
                new_height = int(original_height * ratio)
                img = img.resize((new_width, new_height), Image.LANCZOS)

                logger.info(f"Resized image from {original_width}x{original_height} to {new_width}x{new_height}")
            else:
                logger.info(f"No resizing needed, image size {original_width}x{original_height}")

            # Save to bytes, keeping the original format
            output = io.BytesIO()
            img.save(output, format=original_format)
            resized_data = output.getvalue()

            # Get new metadata
            metadata = {
                "width": img.width,
                "height": img.height,
                "format": original_format,
                "mode": img.mode,
                "resized": ratio < 1
            }

            return resized_data, metadata
        except Exception as e:
            logger.error(f"Error resizing image: {e}")
            # Return original data on error
            return image_data, {"error": str(e)}

    def is_image(self, mime_type: str) -> bool:
        """
        Check if a file is an image based on MIME type

        Args:
            mime_type: File MIME type

        Returns:
            True if the file is an image
        """
        return mime_type.startswith('image/')

    def validate_image(self, image_data: bytes, mime_type: str) -> Tuple[bool, Optional[str]]:
        """
        Validate an image file

        Args:
            image_data: Binary image data
            mime_type: File MIME type

        Returns:
            Tuple of (is_valid, error_message)
        """
        if not self.is_image(mime_type):
            return False, "File is not an image"

        try:
            # Try opening the image with PIL
            with Image.open(io.BytesIO(image_data)) as img:
                # Verify image can be read
                img.verify()
                return True, None
        except Exception as e:
            logger.error(f"Invalid image: {e}")
            return False, f"Invalid image: {str(e)}"

# Create a singleton service
image_processor = ImageProcessor()
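
A typical upload path would chain these methods: validate first (a cheap MIME check plus a PIL `verify()`), then resize. A sketch with an illustrative file path (not part of this commit):

```
# Illustrative only.
from src.services.image_processor import image_processor

with open("images/sunset.jpg", "rb") as f:
    data = f.read()

ok, error = image_processor.validate_image(data, "image/jpeg")
if not ok:
    raise ValueError(error)

# Downscale anything larger than 800px on either side, keeping aspect ratio
resized, meta = image_processor.resize_image(data, max_width=800, max_height=800)
print(meta)  # e.g. {"width": 800, "height": 450, ..., "resized": True}
```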
227
src/services/storage.py
Normal file
@ -0,0 +1,227 @@
import io
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, Tuple
from urllib.parse import quote

from fastapi import UploadFile
from google.cloud import storage
from google.oauth2 import service_account
from PIL import Image

from src.core.config import settings

logger = logging.getLogger(__name__)

class StorageService:
    """Service for Google Cloud Storage operations"""

    def __init__(self):
        self.bucket_name = settings.GCS_BUCKET_NAME
        self.client = self._create_storage_client()
        self.bucket = self._get_or_create_bucket()

    def _create_storage_client(self) -> storage.Client:
        """
        Create a Google Cloud Storage client

        Returns:
            Google Cloud Storage client
        """
        try:
            # Check if credentials file exists
            if os.path.exists(settings.GCS_CREDENTIALS_FILE):
                # Use credentials from file
                credentials = service_account.Credentials.from_service_account_file(
                    settings.GCS_CREDENTIALS_FILE
                )
                return storage.Client(credentials=credentials)
            else:
                # Use default credentials (useful for Cloud Run, where credentials
                # are provided by the environment)
                logger.info("Using default credentials for GCS")
                return storage.Client()
        except Exception as e:
            logger.error(f"Error creating GCS client: {e}")
            raise

    def _get_or_create_bucket(self) -> storage.Bucket:
        """
        Get or create a GCS bucket

        Returns:
            Google Cloud Storage bucket
        """
        try:
            # Check if bucket exists
            bucket = self.client.bucket(self.bucket_name)
            if not bucket.exists():
                logger.info(f"Creating bucket: {self.bucket_name}")
                bucket = self.client.create_bucket(self.bucket_name)
            return bucket
        except Exception as e:
            logger.error(f"Error getting or creating bucket: {e}")
            raise

    async def upload_file(self, file: UploadFile, team_id: str) -> Tuple[str, str, int, dict]:
        """
        Upload a file to Google Cloud Storage

        Args:
            file: File to upload
            team_id: Team ID for organization

        Returns:
            Tuple of (storage_path, content_type, file_size, metadata)
        """
        try:
            # Generate a unique filename
            original_filename = file.filename
            extension = os.path.splitext(original_filename)[1] if original_filename else ""
            unique_filename = f"{uuid.uuid4().hex}{extension}"

            # Create a storage path with team ID for organization
            storage_path = f"{team_id}/{unique_filename}"

            # Get file content
            content = await file.read()
            file_size = len(content)

            # Create a blob in the bucket
            blob = self.bucket.blob(storage_path)

            # Set content type
            content_type = file.content_type
            if content_type:
                blob.content_type = content_type

            # Set metadata
            metadata = {}
            try:
                # Extract image metadata if it's an image
                if content_type and content_type.startswith('image/'):
                    # Read the uploaded bytes in memory with PIL
                    with Image.open(io.BytesIO(content)) as img:
                        metadata = {
                            'width': img.width,
                            'height': img.height,
                            'format': img.format,
                            'mode': img.mode
                        }
            except Exception as e:
                logger.warning(f"Failed to extract image metadata: {e}")

            # Set custom metadata
            blob.metadata = {
                'original_filename': original_filename,
                'upload_time': datetime.utcnow().isoformat(),
                'team_id': team_id
            }

            # Upload the file
            blob.upload_from_string(content, content_type=content_type)

            logger.info(f"File uploaded: {storage_path}")

            # Seek back to the beginning for future reads
            await file.seek(0)

            return storage_path, content_type, file_size, metadata
        except Exception as e:
            logger.error(f"Error uploading file: {e}")
            raise

    def get_file(self, storage_path: str) -> Optional[bytes]:
        """
        Get a file from Google Cloud Storage

        Args:
            storage_path: Storage path of the file

        Returns:
            File content or None if not found
        """
        try:
            blob = self.bucket.blob(storage_path)
            if not blob.exists():
                logger.warning(f"File not found: {storage_path}")
                return None

            return blob.download_as_bytes()
        except Exception as e:
            logger.error(f"Error getting file: {e}")
            raise

    def delete_file(self, storage_path: str) -> bool:
        """
        Delete a file from Google Cloud Storage

        Args:
            storage_path: Storage path of the file

        Returns:
            True if file was deleted, False if not found
        """
        try:
            blob = self.bucket.blob(storage_path)
            if not blob.exists():
                logger.warning(f"File not found for deletion: {storage_path}")
                return False

            blob.delete()
            logger.info(f"File deleted: {storage_path}")
            return True
        except Exception as e:
            logger.error(f"Error deleting file: {e}")
            raise

    def generate_public_url(self, storage_path: str) -> str:
        """
        Generate a public URL for a file

        Args:
            storage_path: Storage path of the file

        Returns:
            Public URL
        """
        try:
            blob = self.bucket.blob(storage_path)
            return blob.public_url
        except Exception as e:
            logger.error(f"Error generating public URL: {e}")
            raise

    def generate_signed_url(self, storage_path: str, expiration_minutes: int = 30) -> Optional[str]:
        """
        Generate a signed URL for temporary access to a file

        Args:
            storage_path: Storage path of the file
            expiration_minutes: Minutes until the URL expires

        Returns:
            Signed URL or None if file not found
        """
        try:
            blob = self.bucket.blob(storage_path)
            if not blob.exists():
                logger.warning(f"File not found for signed URL: {storage_path}")
                return None

            expiration = datetime.utcnow() + timedelta(minutes=expiration_minutes)

            url = blob.generate_signed_url(
                version="v4",
                expiration=expiration,
                method="GET"
            )

            return url
        except Exception as e:
            logger.error(f"Error generating signed URL: {e}")
            raise

# Create a singleton service
storage_service = StorageService()
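
Since `generate_public_url` only works for publicly readable objects, signed URLs are the safer default for a private bucket. A sketch with an illustrative storage path, not part of this commit (credentials resolve as in `_create_storage_client`):

```
# Illustrative only; the path follows the "{team_id}/{unique_filename}" layout.
from src.services.storage import storage_service

path = "507f1f77bcf86cd799439022/1234567890abcdef.jpg"

url = storage_service.generate_signed_url(path, expiration_minutes=15)
if url is None:
    print("object not found")
else:
    print(url)  # time-limited GET access; no bucket-wide exposure
```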
182
src/services/vector_store.py
Normal file
@ -0,0 +1,182 @@
import logging
from typing import List, Dict, Any, Optional, Tuple
import pinecone
from bson import ObjectId

from src.core.config import settings

logger = logging.getLogger(__name__)

class VectorStoreService:
    """Service for managing vector embeddings in Pinecone"""

    def __init__(self):
        self.api_key = settings.VECTOR_DB_API_KEY
        self.environment = settings.VECTOR_DB_ENVIRONMENT
        self.index_name = settings.VECTOR_DB_INDEX_NAME
        self.dimension = 512  # CLIP model embedding dimension
        self.initialized = False
        self.index = None

    def initialize(self):
        """
        Initialize Pinecone connection and create index if needed
        """
        if self.initialized:
            return

        if not self.api_key or not self.environment:
            logger.warning("Pinecone API key or environment not provided, vector search disabled")
            return

        try:
            logger.info(f"Initializing Pinecone with environment {self.environment}")

            # Initialize Pinecone
            pinecone.init(
                api_key=self.api_key,
                environment=self.environment
            )

            # Check if index exists
            if self.index_name not in pinecone.list_indexes():
                logger.info(f"Creating Pinecone index: {self.index_name}")

                # Create index
                pinecone.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric="cosine"
                )

            # Connect to index
            self.index = pinecone.Index(self.index_name)
            self.initialized = True
            logger.info("Pinecone initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing Pinecone: {e}")
            # Don't raise - we want to gracefully handle this and fall back to non-vector search

    def store_embedding(self, image_id: str, team_id: str, embedding: List[float], metadata: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """
        Store an embedding in Pinecone

        Args:
            image_id: Image ID
            team_id: Team ID
            embedding: Image embedding
            metadata: Additional metadata

        Returns:
            Vector ID if successful, None otherwise
        """
        self.initialize()

        if not self.initialized:
            logger.warning("Pinecone not initialized, cannot store embedding")
            return None

        try:
            # Create metadata dict
            meta = {
                "image_id": image_id,
                "team_id": team_id
            }

            if metadata:
                meta.update(metadata)

            # Create a unique vector ID
            vector_id = f"{team_id}_{image_id}"

            # Upsert the vector
            self.index.upsert(
                vectors=[
                    (vector_id, embedding, meta)
                ]
            )

            logger.info(f"Stored embedding for image {image_id}")
            return vector_id
        except Exception as e:
            logger.error(f"Error storing embedding: {e}")
            return None

    def search_by_embedding(self, team_id: str, query_embedding: List[float], limit: int = 10) -> List[Dict[str, Any]]:
        """
        Search for similar images by embedding

        Args:
            team_id: Team ID
            query_embedding: Query embedding
            limit: Maximum number of results

        Returns:
            List of results with image ID and similarity score
        """
        self.initialize()

        if not self.initialized:
            logger.warning("Pinecone not initialized, cannot search by embedding")
            return []

        try:
            # Create filter for team_id
            filter_dict = {
                "team_id": {"$eq": team_id}
            }

            # Query the index
            results = self.index.query(
                vector=query_embedding,
                filter=filter_dict,
                top_k=limit,
                include_metadata=True
            )

            # Format the results
            formatted_results = []
            for match in results.matches:
                formatted_results.append({
                    "image_id": match.metadata["image_id"],
                    "score": match.score,
                    "metadata": match.metadata
                })

            return formatted_results
        except Exception as e:
            logger.error(f"Error searching by embedding: {e}")
            return []

    def delete_embedding(self, image_id: str, team_id: str) -> bool:
        """
        Delete an embedding from Pinecone

        Args:
            image_id: Image ID
            team_id: Team ID

        Returns:
            True if successful, False otherwise
        """
        self.initialize()

        if not self.initialized:
            logger.warning("Pinecone not initialized, cannot delete embedding")
            return False

        try:
            # Create the vector ID
            vector_id = f"{team_id}_{image_id}"

            # Delete the vector
            self.index.delete(ids=[vector_id])

            logger.info(f"Deleted embedding for image {image_id}")
            return True
        except Exception as e:
            logger.error(f"Error deleting embedding: {e}")
            return False

# Create a singleton service
vector_store_service = VectorStoreService()
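
Putting the two services together gives the text-to-image search path: vectors are keyed `{team_id}_{image_id}` and every query is filtered on `team_id`, so results never cross tenants. A sketch with illustrative IDs and file path (not part of this commit):

```
# Illustrative only.
from src.services.embedding_service import embedding_service
from src.services.vector_store import vector_store_service

team_id = "507f1f77bcf86cd799439022"
image_id = "507f1f77bcf86cd799439011"

with open("images/sunset.jpg", "rb") as f:
    vec = embedding_service.generate_image_embedding(f.read())

vector_store_service.store_embedding(image_id, team_id, vec)

hits = vector_store_service.search_by_embedding(
    team_id,
    embedding_service.generate_text_embedding("mountain sunset"),
    limit=5,
)
for hit in hits:
    print(hit["image_id"], round(hit["score"], 3))
```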