Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| import json | |
| from flask import Flask, request, jsonify, g | |
| from flask_expects_json import expects_json | |
| from flask_cors import CORS | |
| from PIL import Image | |
| from huggingface_hub import Repository | |
| from flask_apscheduler import APScheduler | |
| import shutil | |
| import sqlite3 | |
| import subprocess | |
| from jsonschema import ValidationError | |
# Runtime mode is read from FLASK_ENV; any value other than "development"
# (including the default "production") counts as production.
MODE = os.environ.get('FLASK_ENV', 'production')
IS_DEV = (MODE == 'development')

app = Flask(__name__, static_url_path='/static')
# Ask jsonify() not to pretty-print its responses.
app.config.update(JSONIFY_PRETTYPRINT_REGULAR=False)
# JSON Schema for the palette-creation payload: an object with exactly two
# keys, a string ``prompt`` and an ``images`` array whose entries each hold a
# five-element ``colors`` string list and an ``imgURL`` string.
#
# ``minProperties``/``maxProperties`` only count keys; on their own they
# accept any two keys (e.g. {"a": 1, "b": 2}). ``required`` is what actually
# forces the named keys to be present, so together they pin the exact shape.
schema = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "images": {
            "type": "array",
            "items": {
                "type": "object",
                "minProperties": 2,
                "maxProperties": 2,
                "required": ["colors", "imgURL"],
                "properties": {
                    "colors": {
                        "type": "array",
                        "items": {
                            "type": "string",
                        },
                        "maxItems": 5,
                        "minItems": 5,
                    },
                    "imgURL": {"type": "string"},
                },
            },
        },
    },
    "minProperties": 2,
    "maxProperties": 2,
    "required": ["prompt", "images"],
}
CORS(app)

# Local working copy of the palette database; request handlers read and
# write this file.
DB_FILE = Path("./data.db")
TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')

# Clone (or reuse) the dataset repo into ./data and bring it up to date.
repo = Repository(
    local_dir="data",
    repo_type="dataset",
    clone_from="huggingface-projects/color-palettes-sd",
    use_auth_token=TOKEN
)
repo.git_pull()

# Seed the local working db from the copy tracked in the dataset repo
# (repo -> local).
shutil.copyfile("./data/data.db", DB_FILE)

# Make sure the ``palettes`` table exists. A failing SELECT is taken to mean
# the table is missing, so it is created. The connection is closed on every
# path; previously it leaked whenever the table had to be created.
db = sqlite3.connect(DB_FILE)
try:
    data = db.execute("SELECT * FROM palettes").fetchall()
    if IS_DEV:
        print(f"Loaded {len(data)} palettes from local db")
except sqlite3.OperationalError:
    db.execute(
        'CREATE TABLE palettes (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, data json, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)')
    db.commit()
finally:
    db.close()
def get_db():
    """Return the SQLite connection for the current app context.

    The first call opens ``DB_FILE`` and caches the connection on ``flask.g``
    as ``_database``. Later calls in the same context reuse it. Rows are
    returned as ``sqlite3.Row`` so columns can be read by name.
    """
    conn = getattr(g, '_database', None)
    if conn is not None:
        return conn
    conn = sqlite3.connect(DB_FILE)
    g._database = conn
    conn.row_factory = sqlite3.Row
    return conn
def close_connection(exception):
    """Close the connection cached by ``get_db()``, if one was opened.

    ``exception`` is accepted but unused. Presumably this is registered as an
    app-context teardown callback; the decorator is not visible here.
    """
    conn = getattr(g, '_database', None)
    if conn is None:
        return
    conn.close()
def update_repository():
    """Publish the local palette database to the dataset repo.

    Steps:
    1. Pull the repo.
    2. Overwrite its ``data.db`` with the local working db.
    3. Regenerate ``data.json`` from that db.
    4. Commit and push.
    """
    from contextlib import closing

    repo.git_pull()
    # Copy the local working db over the repo copy (local -> repo) so the
    # push below publishes the latest palettes.
    shutil.copyfile(DB_FILE, "./data/data.db")
    # A sqlite3 connection used as a context manager only commits or rolls
    # back; it does not close. closing() guarantees the file handle is
    # released.
    with closing(sqlite3.connect("./data/data.db")) as db:
        db.row_factory = sqlite3.Row
        palettes = db.execute("SELECT * FROM palettes").fetchall()
        data = [{'id': row['id'], 'data': json.loads(
            row['data']), 'created_at': row['created_at']} for row in palettes]
    # Compact JSON export of every palette alongside the db file.
    with open('./data/data.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, separators=(',', ':'))
    print("Updating repository")
    # Amend + force-push keeps the repo history to a single commit.
    # NOTE(review): this Popen is never waited on, and push_to_hub below is
    # also non-blocking, so the two pushes may race. Confirm both are needed.
    subprocess.Popen(
        "git add . && git commit --amend -m 'update' && git push --force", cwd="./data", shell=True)
    repo.push_to_hub(blocking=False)
def index():
    """Serve the front-end entry page ``index.html`` from the static folder."""
    page = 'index.html'
    return app.send_static_file(page)
def push():
    """Trigger an immediate dataset sync for a caller holding the token.

    The caller must send the shared secret in a ``token`` request header. It
    is compared with ``TOKEN`` (from ``HUGGING_FACE_HUB_TOKEN``).

    Returns:
        ``jsonify({'success': True})`` after running ``update_repository()``
        when the token matches, otherwise ``("Error", 401)``. A missing
        header now yields 401 instead of an unhandled-KeyError 400.
    """
    import hmac

    supplied = request.headers.get('token')
    # Reject when no server token is configured, so a missing header (None)
    # can never count as a match. compare_digest runs in constant time and
    # does not leak how much of the secret matched.
    if (TOKEN and supplied is not None
            and hmac.compare_digest(supplied.encode(), TOKEN.encode())):
        update_repository()
        return jsonify({'success': True})
    return "Error", 401
def getAllData():
    """Load every row of ``palettes`` with its JSON payload decoded.

    Returns:
        list[dict]: one dict per row with keys ``id``, ``data`` (the decoded
        JSON column) and ``created_at``.
    """
    rows = get_db().execute("SELECT * FROM palettes").fetchall()
    result = []
    for record in rows:
        result.append({
            'id': record['id'],
            'data': json.loads(record['data']),
            'created_at': record['created_at'],
        })
    return result
def getdata():
    """Return every stored palette as a JSON response."""
    palettes = getAllData()
    return jsonify(palettes)
def create():
    """Store the request payload as a new palette and return all palettes.

    The payload is read from ``g.data``. Presumably a ``flask_expects_json``
    decorator (not visible here) populates it after schema validation. It is
    stored JSON-serialized in the ``data`` column.
    """
    payload = g.data
    conn = get_db()
    conn.cursor().execute(
        "INSERT INTO palettes(data) VALUES (?)", [json.dumps(payload)])
    conn.commit()
    return jsonify(getAllData())
def bad_request(error):
    """Render JSON-schema validation failures as a JSON 400 response.

    If ``error.description`` is a ``jsonschema.ValidationError``, respond with
    ``{'error': <message>}`` and status 400. Otherwise return ``error``
    unchanged.
    """
    cause = error.description
    if not isinstance(cause, ValidationError):
        return error
    return jsonify({'error': cause.message}), 400
if __name__ == '__main__':
    # Only production instances sync the dataset repo on a schedule.
    if not IS_DEV:
        print("Starting scheduler -- Running Production")
        scheduler = APScheduler()
        scheduler.add_job(id='Update Dataset Repository',
                          func=update_repository, trigger='interval', hours=1)
        scheduler.start()
    else:
        print("Not Starting scheduler -- Running Development")
    # The Werkzeug interactive debugger allows code execution from the
    # browser. Enable it only in development instead of hard-coding True.
    app.run(host='0.0.0.0', port=int(
        os.environ.get('PORT', 7860)), debug=IS_DEV, use_reloader=IS_DEV)