Spaces:
Sleeping
Sleeping
Migrate to SQLAlchemy and Alembic for database migrations. (#19)
Browse files* Migrate to SQLAlchemy and Alembic for database migrations.
* remove db.init() call from gradio app
* run migrations at app startup
* fix run migrations
* Convert SQLAlchemy rows to simple lists
* add missing config key
- alembic.ini +44 -0
- alembic/README +1 -0
- alembic/env.py +78 -0
- alembic/script.py.mako +28 -0
- alembic/versions/0b01b44ab005_initial.py +56 -0
- app.py +0 -5
- blossomtune_gradio/__main__.py +4 -0
- blossomtune_gradio/config.py +4 -0
- blossomtune_gradio/database.py +61 -30
- blossomtune_gradio/federation.py +118 -125
- blossomtune_gradio/gradio_app.py +0 -2
- blossomtune_gradio/processing.py +22 -14
- blossomtune_gradio/ui/callbacks.py +26 -15
- flower_apps/quickstart_huggingface/pyproject.toml +1 -0
- pyproject.toml +86 -0
- tests/conftest.py +53 -11
- tests/test_federation.py +118 -134
- tests/test_processing.py +40 -18
- uv.lock +0 -0
alembic.ini
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A generic, single database configuration.
|
| 2 |
+
|
| 3 |
+
[alembic]
|
| 4 |
+
|
| 5 |
+
# database URL. This is consumed by the user-maintained env.py script only.
|
| 6 |
+
# other means of configuring database URLs may be customized within the env.py
|
| 7 |
+
# file.
|
| 8 |
+
sqlalchemy.url = driver://user:pass@localhost/dbname
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Logging configuration
|
| 12 |
+
[loggers]
|
| 13 |
+
keys = root,sqlalchemy,alembic
|
| 14 |
+
|
| 15 |
+
[handlers]
|
| 16 |
+
keys = console
|
| 17 |
+
|
| 18 |
+
[formatters]
|
| 19 |
+
keys = generic
|
| 20 |
+
|
| 21 |
+
[logger_root]
|
| 22 |
+
level = WARNING
|
| 23 |
+
handlers = console
|
| 24 |
+
qualname =
|
| 25 |
+
|
| 26 |
+
[logger_sqlalchemy]
|
| 27 |
+
level = WARNING
|
| 28 |
+
handlers =
|
| 29 |
+
qualname = sqlalchemy.engine
|
| 30 |
+
|
| 31 |
+
[logger_alembic]
|
| 32 |
+
level = INFO
|
| 33 |
+
handlers =
|
| 34 |
+
qualname = alembic
|
| 35 |
+
|
| 36 |
+
[handler_console]
|
| 37 |
+
class = StreamHandler
|
| 38 |
+
args = (sys.stderr,)
|
| 39 |
+
level = NOTSET
|
| 40 |
+
formatter = generic
|
| 41 |
+
|
| 42 |
+
[formatter_generic]
|
| 43 |
+
format = %(levelname)-5.5s [%(name)s] %(message)s
|
| 44 |
+
datefmt = %H:%M:%S
|
alembic/README
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pyproject configuration, based on the generic configuration.
|
alembic/env.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from logging.config import fileConfig
|
| 2 |
+
|
| 3 |
+
from sqlalchemy import engine_from_config
|
| 4 |
+
from sqlalchemy import pool
|
| 5 |
+
|
| 6 |
+
from alembic import context
|
| 7 |
+
from blossomtune_gradio import config as cfg
|
| 8 |
+
from blossomtune_gradio.database import Base
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# this is the Alembic Config object, which provides
|
| 12 |
+
# access to the values within the .ini file in use.
|
| 13 |
+
config = context.config
|
| 14 |
+
config.set_main_option("sqlalchemy.url", cfg.SQLALCHEMY_URL)
|
| 15 |
+
|
| 16 |
+
# Interpret the config file for Python logging.
|
| 17 |
+
# This line sets up loggers basically.
|
| 18 |
+
if config.config_file_name is not None:
|
| 19 |
+
fileConfig(config.config_file_name)
|
| 20 |
+
|
| 21 |
+
# add your model's MetaData object here
|
| 22 |
+
# for 'autogenerate' support
|
| 23 |
+
target_metadata = Base.metadata
|
| 24 |
+
|
| 25 |
+
# other values from the config, defined by the needs of env.py,
|
| 26 |
+
# can be acquired:
|
| 27 |
+
# my_important_option = config.get_main_option("my_important_option")
|
| 28 |
+
# ... etc.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def run_migrations_offline() -> None:
|
| 32 |
+
"""Run migrations in 'offline' mode.
|
| 33 |
+
|
| 34 |
+
This configures the context with just a URL
|
| 35 |
+
and not an Engine, though an Engine is acceptable
|
| 36 |
+
here as well. By skipping the Engine creation
|
| 37 |
+
we don't even need a DBAPI to be available.
|
| 38 |
+
|
| 39 |
+
Calls to context.execute() here emit the given string to the
|
| 40 |
+
script output.
|
| 41 |
+
|
| 42 |
+
"""
|
| 43 |
+
url = config.get_main_option("sqlalchemy.url")
|
| 44 |
+
context.configure(
|
| 45 |
+
url=url,
|
| 46 |
+
target_metadata=target_metadata,
|
| 47 |
+
literal_binds=True,
|
| 48 |
+
dialect_opts={"paramstyle": "named"},
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
with context.begin_transaction():
|
| 52 |
+
context.run_migrations()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def run_migrations_online() -> None:
|
| 56 |
+
"""Run migrations in 'online' mode.
|
| 57 |
+
|
| 58 |
+
In this scenario we need to create an Engine
|
| 59 |
+
and associate a connection with the context.
|
| 60 |
+
|
| 61 |
+
"""
|
| 62 |
+
connectable = engine_from_config(
|
| 63 |
+
config.get_section(config.config_ini_section, {}),
|
| 64 |
+
prefix="sqlalchemy.",
|
| 65 |
+
poolclass=pool.NullPool,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
with connectable.connect() as connection:
|
| 69 |
+
context.configure(connection=connection, target_metadata=target_metadata)
|
| 70 |
+
|
| 71 |
+
with context.begin_transaction():
|
| 72 |
+
context.run_migrations()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if context.is_offline_mode():
|
| 76 |
+
run_migrations_offline()
|
| 77 |
+
else:
|
| 78 |
+
run_migrations_online()
|
alembic/script.py.mako
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""${message}
|
| 2 |
+
|
| 3 |
+
Revision ID: ${up_revision}
|
| 4 |
+
Revises: ${down_revision | comma,n}
|
| 5 |
+
Create Date: ${create_date}
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
from typing import Sequence, Union
|
| 9 |
+
|
| 10 |
+
from alembic import op
|
| 11 |
+
import sqlalchemy as sa
|
| 12 |
+
${imports if imports else ""}
|
| 13 |
+
|
| 14 |
+
# revision identifiers, used by Alembic.
|
| 15 |
+
revision: str = ${repr(up_revision)}
|
| 16 |
+
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
| 17 |
+
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
| 18 |
+
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def upgrade() -> None:
|
| 22 |
+
"""Upgrade schema."""
|
| 23 |
+
${upgrades if upgrades else "pass"}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def downgrade() -> None:
|
| 27 |
+
"""Downgrade schema."""
|
| 28 |
+
${downgrades if downgrades else "pass"}
|
alembic/versions/0b01b44ab005_initial.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""initial
|
| 2 |
+
|
| 3 |
+
Revision ID: 0b01b44ab005
|
| 4 |
+
Revises:
|
| 5 |
+
Create Date: 2025-10-08 12:13:39.666969
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Sequence, Union
|
| 10 |
+
|
| 11 |
+
from alembic import op
|
| 12 |
+
import sqlalchemy as sa
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# revision identifiers, used by Alembic.
|
| 16 |
+
revision: str = "0b01b44ab005"
|
| 17 |
+
down_revision: Union[str, Sequence[str], None] = None
|
| 18 |
+
branch_labels: Union[str, Sequence[str], None] = None
|
| 19 |
+
depends_on: Union[str, Sequence[str], None] = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def upgrade() -> None:
|
| 23 |
+
"""Upgrade schema."""
|
| 24 |
+
# ### commands auto generated by Alembic - please adjust! ###
|
| 25 |
+
op.create_table(
|
| 26 |
+
"config",
|
| 27 |
+
sa.Column("key", sa.String(), nullable=False),
|
| 28 |
+
sa.Column("value", sa.String(), nullable=False),
|
| 29 |
+
sa.PrimaryKeyConstraint("key"),
|
| 30 |
+
)
|
| 31 |
+
op.create_table(
|
| 32 |
+
"requests",
|
| 33 |
+
sa.Column("participant_id", sa.String(), nullable=False),
|
| 34 |
+
sa.Column("status", sa.String(), nullable=False),
|
| 35 |
+
sa.Column(
|
| 36 |
+
"timestamp",
|
| 37 |
+
sa.DateTime(),
|
| 38 |
+
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
| 39 |
+
nullable=False,
|
| 40 |
+
),
|
| 41 |
+
sa.Column("partition_id", sa.Integer(), nullable=True),
|
| 42 |
+
sa.Column("email", sa.String(), nullable=True),
|
| 43 |
+
sa.Column("hf_handle", sa.String(), nullable=True),
|
| 44 |
+
sa.Column("activation_code", sa.String(), nullable=True),
|
| 45 |
+
sa.Column("is_activated", sa.Integer(), nullable=False),
|
| 46 |
+
sa.PrimaryKeyConstraint("participant_id"),
|
| 47 |
+
)
|
| 48 |
+
# ### end Alembic commands ###
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def downgrade() -> None:
|
| 52 |
+
"""Downgrade schema."""
|
| 53 |
+
# ### commands auto generated by Alembic - please adjust! ###
|
| 54 |
+
op.drop_table("requests")
|
| 55 |
+
op.drop_table("config")
|
| 56 |
+
# ### end Alembic commands ###
|
app.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
from blossomtune_gradio.gradio_app import demo
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
if __name__ == "__main__":
|
| 5 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
blossomtune_gradio/__main__.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
|
|
|
|
|
| 1 |
from blossomtune_gradio.gradio_app import demo
|
| 2 |
|
| 3 |
|
| 4 |
if __name__ == "__main__":
|
|
|
|
|
|
|
| 5 |
demo.launch()
|
|
|
|
| 1 |
+
from blossomtune_gradio import config as cfg
|
| 2 |
+
from blossomtune_gradio import database as db
|
| 3 |
from blossomtune_gradio.gradio_app import demo
|
| 4 |
|
| 5 |
|
| 6 |
if __name__ == "__main__":
|
| 7 |
+
if cfg.RUN_MIGRATIONS_ON_STARTUP:
|
| 8 |
+
db.run_migrations()
|
| 9 |
demo.launch()
|
blossomtune_gradio/config.py
CHANGED
|
@@ -13,6 +13,7 @@ SPACE_OWNER = os.getenv("SPACE_OWNER", SPACE_ID.split("/")[0] if SPACE_ID else N
|
|
| 13 |
DB_PATH = (
|
| 14 |
"/data/db/federation.db" if os.path.isdir("/data/db") else "./data/db/federation.db"
|
| 15 |
)
|
|
|
|
| 16 |
MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
|
| 17 |
SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
|
| 18 |
SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
|
|
@@ -25,6 +26,9 @@ SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1:9092")
|
|
| 25 |
SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
|
| 26 |
SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
|
| 27 |
SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# TLS root cert path. For production only.
|
| 30 |
TLS_CERT_DIR = os.getenv("TLS_CERT_DIR", "./certs/")
|
|
|
|
| 13 |
DB_PATH = (
|
| 14 |
"/data/db/federation.db" if os.path.isdir("/data/db") else "./data/db/federation.db"
|
| 15 |
)
|
| 16 |
+
SQLALCHEMY_URL = f"sqlite:///{os.path.abspath(DB_PATH)}"
|
| 17 |
MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
|
| 18 |
SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
|
| 19 |
SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
|
|
|
|
| 26 |
SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
|
| 27 |
SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
|
| 28 |
SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
|
| 29 |
+
RUN_MIGRATIONS_ON_STARTUP = util.strtobool(
|
| 30 |
+
os.getenv("RUN_MIGRATIONS_ON_STARTUP", "true")
|
| 31 |
+
) # Set to false in prod.
|
| 32 |
|
| 33 |
# TLS root cert path. For production only.
|
| 34 |
TLS_CERT_DIR = os.getenv("TLS_CERT_DIR", "./certs/")
|
blossomtune_gradio/database.py
CHANGED
|
@@ -1,34 +1,65 @@
|
|
| 1 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from blossomtune_gradio import config as cfg
|
| 3 |
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
"
|
| 32 |
-
("num_partitions", "10"),
|
| 33 |
)
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from alembic import config
|
| 2 |
+
from sqlalchemy import create_engine, Column, String, Integer, DateTime, func
|
| 3 |
+
from sqlalchemy.orm import sessionmaker, declarative_base
|
| 4 |
+
|
| 5 |
+
|
| 6 |
from blossomtune_gradio import config as cfg
|
| 7 |
|
| 8 |
|
| 9 |
+
Base = declarative_base()
|
| 10 |
+
engine = create_engine(cfg.SQLALCHEMY_URL)
|
| 11 |
+
|
| 12 |
+
# The sessionmaker factory generates new Session objects when called.
|
| 13 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Request(Base):
|
| 17 |
+
"""
|
| 18 |
+
SQLAlchemy model for the 'requests' table.
|
| 19 |
+
This table stores information about participants wanting to join the federation.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
__tablename__ = "requests"
|
| 23 |
+
|
| 24 |
+
participant_id = Column(String, primary_key=True)
|
| 25 |
+
status = Column(String, nullable=False, default="pending")
|
| 26 |
+
timestamp = Column(DateTime, nullable=False, server_default=func.now())
|
| 27 |
+
partition_id = Column(Integer, nullable=True)
|
| 28 |
+
email = Column(String, nullable=True)
|
| 29 |
+
hf_handle = Column(String, nullable=True)
|
| 30 |
+
activation_code = Column(String, nullable=True)
|
| 31 |
+
is_activated = Column(Integer, nullable=False, default=0)
|
| 32 |
+
|
| 33 |
+
def __repr__(self):
|
| 34 |
+
return (
|
| 35 |
+
f"<Request(participant_id='{self.participant_id}', status='{self.status}')>"
|
|
|
|
| 36 |
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Config(Base):
|
| 40 |
+
"""
|
| 41 |
+
SQLAlchemy model for the 'config' table.
|
| 42 |
+
A simple key-value store for application settings.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
__tablename__ = "config"
|
| 46 |
+
|
| 47 |
+
key = Column(String, primary_key=True)
|
| 48 |
+
value = Column(String, nullable=False)
|
| 49 |
+
|
| 50 |
+
def __repr__(self):
|
| 51 |
+
return f"<Config(key='{self.key}', value='{self.value}')>"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def run_migrations():
|
| 55 |
+
"""
|
| 56 |
+
Applies any pending Alembic migrations to the database.
|
| 57 |
+
This should be called on application startup.
|
| 58 |
+
"""
|
| 59 |
+
print("Running database migrations...")
|
| 60 |
+
alembicArgs = [
|
| 61 |
+
"--raiseerr",
|
| 62 |
+
"upgrade",
|
| 63 |
+
"head",
|
| 64 |
+
]
|
| 65 |
+
config.main(argv=alembicArgs)
|
blossomtune_gradio/federation.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
import string
|
| 2 |
import secrets
|
| 3 |
-
import sqlite3
|
| 4 |
-
|
| 5 |
-
from datetime import datetime
|
| 6 |
|
| 7 |
from blossomtune_gradio import config as cfg
|
| 8 |
from blossomtune_gradio import mail
|
| 9 |
from blossomtune_gradio import util
|
| 10 |
from blossomtune_gradio.settings import settings
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def generate_participant_id(length=6):
|
|
@@ -23,105 +21,99 @@ def generate_activation_code(length=8):
|
|
| 23 |
|
| 24 |
|
| 25 |
def check_participant_status(pid_to_check: str, email: str, activation_code: str):
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
if activation_code:
|
| 30 |
-
|
| 31 |
-
"SELECT participant_id, status, partition_id, is_activated, activation_code FROM requests WHERE hf_handle = ? AND email = ? AND activation_code = ?",
|
| 32 |
-
(pid_to_check, email, activation_code),
|
| 33 |
-
)
|
| 34 |
-
else:
|
| 35 |
-
cursor.execute(
|
| 36 |
-
"SELECT participant_id, status, partition_id, is_activated, activation_code FROM requests WHERE hf_handle = ? AND email = ?",
|
| 37 |
-
(pid_to_check, email),
|
| 38 |
-
)
|
| 39 |
-
result = cursor.fetchone()
|
| 40 |
|
| 41 |
-
|
| 42 |
-
num_partitions_res = cursor.fetchone()
|
| 43 |
-
num_partitions = num_partitions_res[0] if num_partitions_res else "10"
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
if approved_count >= cfg.MAX_NUM_NODES:
|
| 57 |
return (False, settings.get_text("federation_full_md"), None)
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
datetime.utcnow().isoformat(),
|
| 70 |
-
pid_to_check,
|
| 71 |
-
email,
|
| 72 |
-
new_activation_code,
|
| 73 |
-
0,
|
| 74 |
-
),
|
| 75 |
-
)
|
| 76 |
-
return (True, settings.get_text("registration_submitted_md"), None)
|
| 77 |
-
else:
|
| 78 |
-
return (False, message, None)
|
| 79 |
-
|
| 80 |
-
# Existing user
|
| 81 |
-
participant_id, status, partition_id, is_activated, stored_code = result
|
| 82 |
-
|
| 83 |
-
# Case 2: User is activating their account
|
| 84 |
-
if not is_activated:
|
| 85 |
-
if activation_code == stored_code:
|
| 86 |
-
with sqlite3.connect(cfg.DB_PATH) as conn:
|
| 87 |
-
conn.execute(
|
| 88 |
-
"UPDATE requests SET is_activated = 1 WHERE hf_handle = ?",
|
| 89 |
-
(pid_to_check,),
|
| 90 |
)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
if not activation_code:
|
| 96 |
-
return (False, settings.get_text("missing_activation_code_md"))
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def manage_request(participant_id: str, partition_id: str, action: str):
|
|
@@ -129,29 +121,31 @@ def manage_request(participant_id: str, partition_id: str, action: str):
|
|
| 129 |
if not participant_id:
|
| 130 |
return False, "Please select a participant from the pending requests table."
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
)
|
| 143 |
-
is_activated_res = cursor.fetchone()
|
| 144 |
-
if not is_activated_res or not is_activated_res[0]:
|
| 145 |
return (
|
| 146 |
False,
|
| 147 |
settings.get_text("participant_not_activated_warning_md"),
|
| 148 |
)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
(p_id_int,)
|
|
|
|
| 153 |
)
|
| 154 |
-
|
|
|
|
| 155 |
return (
|
| 156 |
False,
|
| 157 |
settings.get_text(
|
|
@@ -159,20 +153,17 @@ def manage_request(participant_id: str, partition_id: str, action: str):
|
|
| 159 |
),
|
| 160 |
)
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
)
|
| 166 |
return (
|
| 167 |
True,
|
| 168 |
f"Participant {participant_id} is allowed to join the federation.",
|
| 169 |
)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
("denied", participant_id),
|
| 175 |
-
)
|
| 176 |
return (
|
| 177 |
True,
|
| 178 |
f"Participant {participant_id} is not allowed to join the federation.",
|
|
@@ -180,12 +171,14 @@ def manage_request(participant_id: str, partition_id: str, action: str):
|
|
| 180 |
|
| 181 |
|
| 182 |
def get_next_partion_id() -> int:
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
-
used_ids = {row[0] for row in
|
| 189 |
|
| 190 |
next_id = 0
|
| 191 |
while next_id in used_ids:
|
|
|
|
| 1 |
import string
|
| 2 |
import secrets
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from blossomtune_gradio import config as cfg
|
| 5 |
from blossomtune_gradio import mail
|
| 6 |
from blossomtune_gradio import util
|
| 7 |
from blossomtune_gradio.settings import settings
|
| 8 |
+
from blossomtune_gradio.database import SessionLocal, Request, Config
|
| 9 |
|
| 10 |
|
| 11 |
def generate_participant_id(length=6):
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def check_participant_status(pid_to_check: str, email: str, activation_code: str):
|
| 24 |
+
"""
|
| 25 |
+
Handles a participant's request to join, activate, or check status using SQLAlchemy.
|
| 26 |
+
Returns a tuple: (is_approved: bool, message: str, data: any | None)
|
| 27 |
+
The 'is_approved' boolean is True ONLY when the participant's final status is 'approved'.
|
| 28 |
+
"""
|
| 29 |
+
with SessionLocal() as db:
|
| 30 |
+
query = db.query(Request).filter(
|
| 31 |
+
Request.hf_handle == pid_to_check, Request.email == email
|
| 32 |
+
)
|
| 33 |
if activation_code:
|
| 34 |
+
query = query.filter(Request.activation_code == activation_code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
request = query.first()
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
num_partitions_config = (
|
| 39 |
+
db.query(Config).filter(Config.key == "num_partitions").first()
|
| 40 |
+
)
|
| 41 |
+
num_partitions = num_partitions_config.value if num_partitions_config else "10"
|
| 42 |
+
|
| 43 |
+
# Case 1: New user registration
|
| 44 |
+
if request is None:
|
| 45 |
+
if activation_code:
|
| 46 |
+
return (False, settings.get_text("activation_invalid_md"), None)
|
| 47 |
+
if not util.validate_email(email):
|
| 48 |
+
return (False, settings.get_text("invalid_email_md"), None)
|
| 49 |
+
|
| 50 |
+
approved_count = (
|
| 51 |
+
db.query(Request).filter(Request.status == "approved").count()
|
| 52 |
+
)
|
| 53 |
if approved_count >= cfg.MAX_NUM_NODES:
|
| 54 |
return (False, settings.get_text("federation_full_md"), None)
|
| 55 |
|
| 56 |
+
participant_id = generate_participant_id()
|
| 57 |
+
new_activation_code = generate_activation_code()
|
| 58 |
+
mail_sent, message = mail.send_activation_email(email, new_activation_code)
|
| 59 |
+
|
| 60 |
+
if mail_sent:
|
| 61 |
+
new_request = Request(
|
| 62 |
+
participant_id=participant_id,
|
| 63 |
+
hf_handle=pid_to_check,
|
| 64 |
+
email=email,
|
| 65 |
+
activation_code=new_activation_code,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
)
|
| 67 |
+
db.add(new_request)
|
| 68 |
+
db.commit()
|
| 69 |
+
# A successful registration step, but not yet approved for federation.
|
| 70 |
+
return (False, settings.get_text("registration_submitted_md"), None)
|
| 71 |
+
else:
|
| 72 |
+
return (False, message, None)
|
| 73 |
+
|
| 74 |
+
# Case 2: User is activating their account
|
| 75 |
+
if not request.is_activated:
|
| 76 |
+
if activation_code == request.activation_code:
|
| 77 |
+
request.is_activated = 1
|
| 78 |
+
db.commit()
|
| 79 |
+
# A successful activation step, but not yet approved.
|
| 80 |
+
return (False, settings.get_text("activation_successful_md"), None)
|
| 81 |
+
else:
|
| 82 |
+
return (False, settings.get_text("activation_invalid_md"), None)
|
| 83 |
+
|
| 84 |
+
# At this point, user is activated.
|
| 85 |
+
# They must provide the activation code to check their final status.
|
| 86 |
if not activation_code:
|
| 87 |
+
return (False, settings.get_text("missing_activation_code_md"), None)
|
| 88 |
+
|
| 89 |
+
# Case 3: Activated user is checking their final status
|
| 90 |
+
if request.status == "approved":
|
| 91 |
+
hostname = (
|
| 92 |
+
"localhost"
|
| 93 |
+
if not cfg.SPACE_ID
|
| 94 |
+
else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
|
| 95 |
+
)
|
| 96 |
+
superlink_hostname = cfg.SUPERLINK_HOST or hostname
|
| 97 |
+
|
| 98 |
+
connection_string = settings.get_text(
|
| 99 |
+
"status_approved_md",
|
| 100 |
+
participant_id=request.participant_id,
|
| 101 |
+
partition_id=request.partition_id,
|
| 102 |
+
superlink_hostname=superlink_hostname,
|
| 103 |
+
num_partitions=num_partitions,
|
| 104 |
+
)
|
| 105 |
+
# The user is fully approved. Return success and the cert path.
|
| 106 |
+
return (True, connection_string, cfg.BLOSSOMTUNE_TLS_CERT_PATH)
|
| 107 |
+
elif request.status == "pending":
|
| 108 |
+
return (False, settings.get_text("status_pending_md"), None)
|
| 109 |
+
else: # Denied
|
| 110 |
+
return (
|
| 111 |
+
False,
|
| 112 |
+
settings.get_text(
|
| 113 |
+
"status_denied_md", participant_id=request.participant_id
|
| 114 |
+
),
|
| 115 |
+
None,
|
| 116 |
+
)
|
| 117 |
|
| 118 |
|
| 119 |
def manage_request(participant_id: str, partition_id: str, action: str):
|
|
|
|
| 121 |
if not participant_id:
|
| 122 |
return False, "Please select a participant from the pending requests table."
|
| 123 |
|
| 124 |
+
with SessionLocal() as db:
|
| 125 |
+
request = (
|
| 126 |
+
db.query(Request).filter(Request.participant_id == participant_id).first()
|
| 127 |
+
)
|
| 128 |
+
if not request:
|
| 129 |
+
return False, "Participant not found."
|
| 130 |
|
| 131 |
+
if action == "approve":
|
| 132 |
+
if not partition_id or not partition_id.isdigit():
|
| 133 |
+
return False, "Please provide a valid integer for the Partition ID."
|
| 134 |
+
|
| 135 |
+
p_id_int = int(partition_id)
|
| 136 |
+
if not request.is_activated:
|
|
|
|
|
|
|
|
|
|
| 137 |
return (
|
| 138 |
False,
|
| 139 |
settings.get_text("participant_not_activated_warning_md"),
|
| 140 |
)
|
| 141 |
|
| 142 |
+
existing_participant = (
|
| 143 |
+
db.query(Request)
|
| 144 |
+
.filter(Request.partition_id == p_id_int, Request.status == "approved")
|
| 145 |
+
.first()
|
| 146 |
)
|
| 147 |
+
|
| 148 |
+
if existing_participant:
|
| 149 |
return (
|
| 150 |
False,
|
| 151 |
settings.get_text(
|
|
|
|
| 153 |
),
|
| 154 |
)
|
| 155 |
|
| 156 |
+
request.status = "approved"
|
| 157 |
+
request.partition_id = p_id_int
|
| 158 |
+
db.commit()
|
|
|
|
| 159 |
return (
|
| 160 |
True,
|
| 161 |
f"Participant {participant_id} is allowed to join the federation.",
|
| 162 |
)
|
| 163 |
+
else: # Deny
|
| 164 |
+
request.status = "denied"
|
| 165 |
+
request.partition_id = None
|
| 166 |
+
db.commit()
|
|
|
|
|
|
|
| 167 |
return (
|
| 168 |
True,
|
| 169 |
f"Participant {participant_id} is not allowed to join the federation.",
|
|
|
|
| 171 |
|
| 172 |
|
| 173 |
def get_next_partion_id() -> int:
|
| 174 |
+
"""Finds the lowest available partition ID."""
|
| 175 |
+
with SessionLocal() as db:
|
| 176 |
+
used_ids_query = (
|
| 177 |
+
db.query(Request.partition_id)
|
| 178 |
+
.filter(Request.status == "approved", Request.partition_id.isnot(None))
|
| 179 |
+
.all()
|
| 180 |
)
|
| 181 |
+
used_ids = {row[0] for row in used_ids_query}
|
| 182 |
|
| 183 |
next_id = 0
|
| 184 |
while next_id in used_ids:
|
blossomtune_gradio/gradio_app.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
from blossomtune_gradio import database as db
|
| 4 |
from blossomtune_gradio.ui import components
|
| 5 |
from blossomtune_gradio.ui import callbacks
|
| 6 |
|
| 7 |
|
| 8 |
with gr.Blocks(theme=gr.themes.Soft(), title="Flower Superlink & Runner") as demo:
|
| 9 |
-
db.init()
|
| 10 |
gr.Markdown("BlossomTune 🌸 Flower Superlink & Runner")
|
| 11 |
with gr.Row():
|
| 12 |
login_button = gr.LoginButton()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
|
|
|
| 3 |
from blossomtune_gradio.ui import components
|
| 4 |
from blossomtune_gradio.ui import callbacks
|
| 5 |
|
| 6 |
|
| 7 |
with gr.Blocks(theme=gr.themes.Soft(), title="Flower Superlink & Runner") as demo:
|
|
|
|
| 8 |
gr.Markdown("BlossomTune 🌸 Flower Superlink & Runner")
|
| 9 |
with gr.Row():
|
| 10 |
login_button = gr.LoginButton()
|
blossomtune_gradio/processing.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
import shutil
|
| 3 |
-
import sqlite3
|
| 4 |
import threading
|
| 5 |
import subprocess
|
| 6 |
|
| 7 |
from blossomtune_gradio.logs import log
|
| 8 |
from blossomtune_gradio import config as cfg
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
# In-memory store for background processes and logs
|
|
@@ -44,7 +45,9 @@ def start_superlink():
|
|
| 44 |
|
| 45 |
if process_store["superlink"] and process_store["superlink"].poll() is None:
|
| 46 |
return False, "Superlink process is already running."
|
| 47 |
-
|
|
|
|
|
|
|
| 48 |
threading.Thread(
|
| 49 |
target=run_process, args=(command, "superlink"), daemon=True
|
| 50 |
).start()
|
|
@@ -58,31 +61,36 @@ def start_runner(
|
|
| 58 |
):
|
| 59 |
if process_store["runner"] and process_store["runner"].poll() is None:
|
| 60 |
return False, "A Runner process is already running."
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
return (
|
| 66 |
False,
|
| 67 |
-
"Superlink is not running. Please start it before starting the runner.",
|
| 68 |
)
|
| 69 |
-
|
| 70 |
if not all([runner_app, run_id, num_partitions]):
|
| 71 |
return False, "Please provide a Runner App, Run ID, and Total Partitions."
|
| 72 |
if not num_partitions.isdigit() or int(num_partitions) <= 0:
|
| 73 |
return False, "Total Partitions must be a positive integer."
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
|
| 82 |
runner_app_path = runner_app.replace(".", os.path.sep)
|
| 83 |
if not os.path.exists(runner_app_path):
|
| 84 |
return False, f"Unable to find app path '{runner_app_path}'."
|
| 85 |
|
|
|
|
| 86 |
command = [
|
| 87 |
shutil.which("flwr"),
|
| 88 |
"run",
|
|
|
|
| 1 |
import os
|
| 2 |
import shutil
|
|
|
|
| 3 |
import threading
|
| 4 |
import subprocess
|
| 5 |
|
| 6 |
from blossomtune_gradio.logs import log
|
| 7 |
from blossomtune_gradio import config as cfg
|
| 8 |
+
from blossomtune_gradio import util
|
| 9 |
+
from blossomtune_gradio.database import SessionLocal, Config
|
| 10 |
|
| 11 |
|
| 12 |
# In-memory store for background processes and logs
|
|
|
|
| 45 |
|
| 46 |
if process_store["superlink"] and process_store["superlink"].poll() is None:
|
| 47 |
return False, "Superlink process is already running."
|
| 48 |
+
|
| 49 |
+
# The command needs to be adapted for TLS if it's not insecure
|
| 50 |
+
command = [shutil.which("flower-superlink"), "--insecure"] # Placeholder
|
| 51 |
threading.Thread(
|
| 52 |
target=run_process, args=(command, "superlink"), daemon=True
|
| 53 |
).start()
|
|
|
|
| 61 |
):
|
| 62 |
if process_store["runner"] and process_store["runner"].poll() is None:
|
| 63 |
return False, "A Runner process is already running."
|
| 64 |
+
|
| 65 |
+
# Check if the Superlink is running, respecting the configured mode
|
| 66 |
+
if cfg.SUPERLINK_MODE == "external":
|
| 67 |
+
if not util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT):
|
| 68 |
+
return False, "External Superlink is not running or unreachable."
|
| 69 |
+
elif not (process_store["superlink"] and process_store["superlink"].poll() is None):
|
| 70 |
return (
|
| 71 |
False,
|
| 72 |
+
"Internal Superlink is not running. Please start it before starting the runner.",
|
| 73 |
)
|
| 74 |
+
|
| 75 |
if not all([runner_app, run_id, num_partitions]):
|
| 76 |
return False, "Please provide a Runner App, Run ID, and Total Partitions."
|
| 77 |
if not num_partitions.isdigit() or int(num_partitions) <= 0:
|
| 78 |
return False, "Total Partitions must be a positive integer."
|
| 79 |
|
| 80 |
+
# Update the number of partitions in the database using SQLAlchemy
|
| 81 |
+
with SessionLocal() as db:
|
| 82 |
+
config_entry = db.query(Config).filter(Config.key == "num_partitions").first()
|
| 83 |
+
if config_entry:
|
| 84 |
+
config_entry.value = num_partitions
|
| 85 |
+
else:
|
| 86 |
+
db.add(Config(key="num_partitions", value=num_partitions))
|
| 87 |
+
db.commit()
|
| 88 |
|
| 89 |
runner_app_path = runner_app.replace(".", os.path.sep)
|
| 90 |
if not os.path.exists(runner_app_path):
|
| 91 |
return False, f"Unable to find app path '{runner_app_path}'."
|
| 92 |
|
| 93 |
+
# Construct the command for a TLS-enabled runner
|
| 94 |
command = [
|
| 95 |
shutil.which("flwr"),
|
| 96 |
"run",
|
blossomtune_gradio/ui/callbacks.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import time
|
| 2 |
-
import sqlite3
|
| 3 |
import gradio as gr
|
| 4 |
import pandas as pd
|
| 5 |
|
|
@@ -9,6 +8,7 @@ from blossomtune_gradio import federation as fed
|
|
| 9 |
from blossomtune_gradio import processing
|
| 10 |
from blossomtune_gradio.settings import settings
|
| 11 |
from blossomtune_gradio import util
|
|
|
|
| 12 |
|
| 13 |
from . import components
|
| 14 |
from . import auth
|
|
@@ -46,16 +46,31 @@ def get_full_status_update(
|
|
| 46 |
else:
|
| 47 |
auth_status = settings.get_text("auth_status_local_mode_md")
|
| 48 |
|
| 49 |
-
with
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# Superlink Status Logic
|
| 58 |
-
superlink_btn_update = gr.update()
|
| 59 |
|
| 60 |
if cfg.SUPERLINK_MODE == "internal":
|
| 61 |
superlink_is_running = (
|
|
@@ -78,7 +93,6 @@ def get_full_status_update(
|
|
| 78 |
else:
|
| 79 |
is_open = util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT)
|
| 80 |
superlink_status = "🟢 Running" if is_open else "🔴 Not Running"
|
| 81 |
-
# Disable the button in external mode
|
| 82 |
superlink_btn_update = gr.update(value="Managed Externally", interactive=False)
|
| 83 |
else:
|
| 84 |
superlink_status = "⚠️ Invalid Mode"
|
|
@@ -126,7 +140,6 @@ def toggle_superlink(
|
|
| 126 |
):
|
| 127 |
"""Toggles the Superlink process on or off."""
|
| 128 |
if not auth.is_space_owner(profile, oauth_token):
|
| 129 |
-
# Hardcode warning text as it's not in the schema
|
| 130 |
gr.Warning("You are not authorized to perform this operation.")
|
| 131 |
return
|
| 132 |
if (
|
|
@@ -147,7 +160,6 @@ def toggle_runner(
|
|
| 147 |
):
|
| 148 |
"""Toggles the Runner process on or off."""
|
| 149 |
if not auth.is_space_owner(profile, oauth_token):
|
| 150 |
-
# Hardcode warning text as it's not in the schema
|
| 151 |
gr.Warning("You are not authorized to perform this operation.")
|
| 152 |
return
|
| 153 |
if (
|
|
@@ -204,8 +216,8 @@ def on_check_participant_status(
|
|
| 204 |
pid_to_check = user_hf_handle.strip()
|
| 205 |
email_to_add = email.strip()
|
| 206 |
activation_code_to_check = activation_code.strip()
|
| 207 |
-
|
| 208 |
-
|
| 209 |
pid_to_check, email_to_add, activation_code_to_check
|
| 210 |
)
|
| 211 |
return {
|
|
@@ -217,7 +229,6 @@ def on_check_participant_status(
|
|
| 217 |
|
| 218 |
|
| 219 |
def on_manage_fed_request(participant_id: str, partition_id: str, action: str):
|
| 220 |
-
# The federation module is responsible for getting the correct text from settings
|
| 221 |
result, message = fed.manage_request(participant_id, partition_id, action)
|
| 222 |
if result:
|
| 223 |
gr.Info(message)
|
|
|
|
| 1 |
import time
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import pandas as pd
|
| 4 |
|
|
|
|
| 8 |
from blossomtune_gradio import processing
|
| 9 |
from blossomtune_gradio.settings import settings
|
| 10 |
from blossomtune_gradio import util
|
| 11 |
+
from blossomtune_gradio.database import SessionLocal, Request
|
| 12 |
|
| 13 |
from . import components
|
| 14 |
from . import auth
|
|
|
|
| 46 |
else:
|
| 47 |
auth_status = settings.get_text("auth_status_local_mode_md")
|
| 48 |
|
| 49 |
+
with SessionLocal() as db:
|
| 50 |
+
pending_results = (
|
| 51 |
+
db.query(Request.participant_id, Request.hf_handle, Request.email)
|
| 52 |
+
.filter(Request.status == "pending", Request.is_activated == 1)
|
| 53 |
+
.order_by(Request.timestamp.asc())
|
| 54 |
+
.all()
|
| 55 |
+
)
|
| 56 |
+
approved_results = (
|
| 57 |
+
db.query(
|
| 58 |
+
Request.participant_id,
|
| 59 |
+
Request.hf_handle,
|
| 60 |
+
Request.email,
|
| 61 |
+
Request.partition_id,
|
| 62 |
+
)
|
| 63 |
+
.filter(Request.status == "approved")
|
| 64 |
+
.order_by(Request.timestamp.desc())
|
| 65 |
+
.all()
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Convert SQLAlchemy rows to simple lists
|
| 69 |
+
pending_rows = [list(row) for row in pending_results]
|
| 70 |
+
approved_rows = [list(row) for row in approved_results]
|
| 71 |
|
| 72 |
# Superlink Status Logic
|
| 73 |
+
superlink_btn_update = gr.update()
|
| 74 |
|
| 75 |
if cfg.SUPERLINK_MODE == "internal":
|
| 76 |
superlink_is_running = (
|
|
|
|
| 93 |
else:
|
| 94 |
is_open = util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT)
|
| 95 |
superlink_status = "🟢 Running" if is_open else "🔴 Not Running"
|
|
|
|
| 96 |
superlink_btn_update = gr.update(value="Managed Externally", interactive=False)
|
| 97 |
else:
|
| 98 |
superlink_status = "⚠️ Invalid Mode"
|
|
|
|
| 140 |
):
|
| 141 |
"""Toggles the Superlink process on or off."""
|
| 142 |
if not auth.is_space_owner(profile, oauth_token):
|
|
|
|
| 143 |
gr.Warning("You are not authorized to perform this operation.")
|
| 144 |
return
|
| 145 |
if (
|
|
|
|
| 160 |
):
|
| 161 |
"""Toggles the Runner process on or off."""
|
| 162 |
if not auth.is_space_owner(profile, oauth_token):
|
|
|
|
| 163 |
gr.Warning("You are not authorized to perform this operation.")
|
| 164 |
return
|
| 165 |
if (
|
|
|
|
| 216 |
pid_to_check = user_hf_handle.strip()
|
| 217 |
email_to_add = email.strip()
|
| 218 |
activation_code_to_check = activation_code.strip()
|
| 219 |
+
|
| 220 |
+
approved, message, download = fed.check_participant_status(
|
| 221 |
pid_to_check, email_to_add, activation_code_to_check
|
| 222 |
)
|
| 223 |
return {
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
def on_manage_fed_request(participant_id: str, partition_id: str, action: str):
|
|
|
|
| 232 |
result, message = fed.manage_request(participant_id, partition_id, action)
|
| 233 |
if result:
|
| 234 |
gr.Info(message)
|
flower_apps/quickstart_huggingface/pyproject.toml
CHANGED
|
@@ -60,3 +60,4 @@ options.backend.client-resources.num-gpus = 0.0 # at most 4 ClientApp will run i
|
|
| 60 |
[tool.flwr.federations.local-deployment]
|
| 61 |
address = "0.0.0.0:9093"
|
| 62 |
insecure = true
|
|
|
|
|
|
| 60 |
[tool.flwr.federations.local-deployment]
|
| 61 |
address = "0.0.0.0:9093"
|
| 62 |
insecure = true
|
| 63 |
+
root-certificate = ""
|
pyproject.toml
CHANGED
|
@@ -16,6 +16,8 @@ dependencies = [
|
|
| 16 |
"markupsafe==2.1.3",
|
| 17 |
"jinja2>=3.1.6",
|
| 18 |
"mlx[cpu]>=0.29.2",
|
|
|
|
|
|
|
| 19 |
]
|
| 20 |
|
| 21 |
[tool.uv.sources]
|
|
@@ -90,3 +92,87 @@ build-backend = "setuptools.build_meta"
|
|
| 90 |
where = ["."]
|
| 91 |
include = ["blossomtune_gradio", "flower_apps"]
|
| 92 |
exclude = ["results", "data"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"markupsafe==2.1.3",
|
| 17 |
"jinja2>=3.1.6",
|
| 18 |
"mlx[cpu]>=0.29.2",
|
| 19 |
+
"alembic>=1.16.5",
|
| 20 |
+
"sqlalchemy>=2.0.43",
|
| 21 |
]
|
| 22 |
|
| 23 |
[tool.uv.sources]
|
|
|
|
| 92 |
where = ["."]
|
| 93 |
include = ["blossomtune_gradio", "flower_apps"]
|
| 94 |
exclude = ["results", "data"]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
[tool.alembic]
|
| 98 |
+
|
| 99 |
+
# path to migration scripts.
|
| 100 |
+
# this is typically a path given in POSIX (e.g. forward slashes)
|
| 101 |
+
# format, relative to the token %(here)s which refers to the location of this
|
| 102 |
+
# ini file
|
| 103 |
+
script_location = "%(here)s/alembic"
|
| 104 |
+
|
| 105 |
+
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
| 106 |
+
# Uncomment the line below if you want the files to be prepended with date and time
|
| 107 |
+
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
| 108 |
+
# for all available tokens
|
| 109 |
+
# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s"
|
| 110 |
+
|
| 111 |
+
# additional paths to be prepended to sys.path. defaults to the current working directory.
|
| 112 |
+
prepend_sys_path = [
|
| 113 |
+
"."
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
# timezone to use when rendering the date within the migration file
|
| 117 |
+
# as well as the filename.
|
| 118 |
+
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
| 119 |
+
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
| 120 |
+
# string value is passed to ZoneInfo()
|
| 121 |
+
# leave blank for localtime
|
| 122 |
+
# timezone =
|
| 123 |
+
|
| 124 |
+
# max length of characters to apply to the "slug" field
|
| 125 |
+
# truncate_slug_length = 40
|
| 126 |
+
|
| 127 |
+
# set to 'true' to run the environment during
|
| 128 |
+
# the 'revision' command, regardless of autogenerate
|
| 129 |
+
# revision_environment = false
|
| 130 |
+
|
| 131 |
+
# set to 'true' to allow .pyc and .pyo files without
|
| 132 |
+
# a source .py file to be detected as revisions in the
|
| 133 |
+
# versions/ directory
|
| 134 |
+
# sourceless = false
|
| 135 |
+
|
| 136 |
+
# version location specification; This defaults
|
| 137 |
+
# to <script_location>/versions. When using multiple version
|
| 138 |
+
# directories, initial revisions must be specified with --version-path.
|
| 139 |
+
# version_locations = [
|
| 140 |
+
# "%(here)s/alembic/versions",
|
| 141 |
+
# "%(here)s/foo/bar"
|
| 142 |
+
# ]
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# set to 'true' to search source files recursively
|
| 146 |
+
# in each "version_locations" directory
|
| 147 |
+
# new in Alembic version 1.10
|
| 148 |
+
# recursive_version_locations = false
|
| 149 |
+
|
| 150 |
+
# the output encoding used when revision files
|
| 151 |
+
# are written from script.py.mako
|
| 152 |
+
# output_encoding = "utf-8"
|
| 153 |
+
|
| 154 |
+
# This section defines scripts or Python functions that are run
|
| 155 |
+
# on newly generated revision scripts. See the documentation for further
|
| 156 |
+
# detail and examples
|
| 157 |
+
# [[tool.alembic.post_write_hooks]]
|
| 158 |
+
# format using "black" - use the console_scripts runner,
|
| 159 |
+
# against the "black" entrypoint
|
| 160 |
+
# name = "black"
|
| 161 |
+
# type = "console_scripts"
|
| 162 |
+
# entrypoint = "black"
|
| 163 |
+
# options = "-l 79 REVISION_SCRIPT_FILENAME"
|
| 164 |
+
#
|
| 165 |
+
# [[tool.alembic.post_write_hooks]]
|
| 166 |
+
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
| 167 |
+
# name = "ruff"
|
| 168 |
+
# type = "module"
|
| 169 |
+
# module = "ruff"
|
| 170 |
+
# options = "check --fix REVISION_SCRIPT_FILENAME"
|
| 171 |
+
#
|
| 172 |
+
# [[tool.alembic.post_write_hooks]]
|
| 173 |
+
# Alternatively, use the exec runner to execute a binary found on your PATH
|
| 174 |
+
# name = "ruff"
|
| 175 |
+
# type = "exec"
|
| 176 |
+
# executable = "ruff"
|
| 177 |
+
# options = "check --fix REVISION_SCRIPT_FILENAME"
|
| 178 |
+
|
tests/conftest.py
CHANGED
|
@@ -1,18 +1,60 @@
|
|
| 1 |
import pytest
|
| 2 |
-
import
|
| 3 |
|
| 4 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
@pytest.fixture
|
| 8 |
-
def
|
| 9 |
"""
|
| 10 |
-
Fixture to set up
|
| 11 |
-
It
|
| 12 |
-
|
| 13 |
"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pytest
|
| 2 |
+
from unittest.mock import MagicMock
|
| 3 |
|
| 4 |
+
from alembic.config import Config
|
| 5 |
+
from alembic import command
|
| 6 |
+
from sqlalchemy import create_engine
|
| 7 |
+
from sqlalchemy.orm import sessionmaker
|
| 8 |
+
|
| 9 |
+
from blossomtune_gradio import config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture(scope="session")
|
| 13 |
+
def alembic_config():
|
| 14 |
+
"""Fixture to create a valid Alembic Config object."""
|
| 15 |
+
return Config("alembic.ini")
|
| 16 |
|
| 17 |
|
| 18 |
@pytest.fixture
|
| 19 |
+
def db_session(mocker, tmp_path):
|
| 20 |
"""
|
| 21 |
+
Fixture to set up a clean file-based SQLite database for each test function.
|
| 22 |
+
It creates the database schema using Alembic programmatically and ensures all
|
| 23 |
+
modules use the same test database session.
|
| 24 |
"""
|
| 25 |
+
db_file = tmp_path / "test_federation.db"
|
| 26 |
+
db_url = f"sqlite:///{db_file}"
|
| 27 |
+
mocker.patch.object(config, "SQLALCHEMY_URL", db_url)
|
| 28 |
+
|
| 29 |
+
# Create an Alembic Config object and point it to the temp database.
|
| 30 |
+
alembic_cfg = Config()
|
| 31 |
+
alembic_cfg.set_main_option("script_location", "alembic")
|
| 32 |
+
alembic_cfg.set_main_option("sqlalchemy.url", db_url)
|
| 33 |
+
|
| 34 |
+
# Apply the migrations to create the schema in the temporary database.
|
| 35 |
+
command.upgrade(alembic_cfg, "head")
|
| 36 |
+
|
| 37 |
+
# Set up the SQLAlchemy engine and session factory for the tests to use.
|
| 38 |
+
engine = create_engine(db_url)
|
| 39 |
+
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 40 |
+
session = TestingSessionLocal()
|
| 41 |
+
|
| 42 |
+
# Mock the SessionLocal factory in each module where it is imported and used.
|
| 43 |
+
mocker.patch("blossomtune_gradio.federation.SessionLocal", return_value=session)
|
| 44 |
+
mocker.patch("blossomtune_gradio.processing.SessionLocal", return_value=session)
|
| 45 |
+
mocker.patch("blossomtune_gradio.ui.callbacks.SessionLocal", return_value=session)
|
| 46 |
+
|
| 47 |
+
yield session
|
| 48 |
+
|
| 49 |
+
session.close()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@pytest.fixture
|
| 53 |
+
def mock_settings(mocker):
|
| 54 |
+
"""Fixture to mock the settings module, available to all tests."""
|
| 55 |
+
mock_get = MagicMock(
|
| 56 |
+
side_effect=lambda key, **kwargs: f"mock_{key}".format(**kwargs)
|
| 57 |
+
)
|
| 58 |
+
mocker.patch("blossomtune_gradio.federation.settings.get_text", mock_get)
|
| 59 |
+
mocker.patch("blossomtune_gradio.ui.callbacks.settings.get_text", mock_get)
|
| 60 |
+
return mock_get
|
tests/test_federation.py
CHANGED
|
@@ -1,18 +1,8 @@
|
|
| 1 |
import pytest
|
| 2 |
-
from
|
| 3 |
|
| 4 |
from blossomtune_gradio import federation as fed
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@pytest.fixture
|
| 8 |
-
def mock_settings(mocker):
|
| 9 |
-
"""Fixture to mock the settings module."""
|
| 10 |
-
# The lambda returns a formatted string to simulate Jinja2's behavior
|
| 11 |
-
mock_get = MagicMock(
|
| 12 |
-
side_effect=lambda key, **kwargs: f"mock_{key}".format(**kwargs)
|
| 13 |
-
)
|
| 14 |
-
mocker.patch("blossomtune_gradio.federation.settings.get_text", mock_get)
|
| 15 |
-
return mock_get
|
| 16 |
|
| 17 |
|
| 18 |
@pytest.fixture
|
|
@@ -38,208 +28,202 @@ def test_generate_activation_code():
|
|
| 38 |
|
| 39 |
|
| 40 |
class TestCheckParticipantStatus:
|
| 41 |
-
"""Test suite for the check_participant_status function."""
|
| 42 |
|
| 43 |
-
def test_new_user_registration_success(
|
| 44 |
-
self, in_memory_db, mock_settings, mock_mail
|
| 45 |
-
):
|
| 46 |
"""Verify successful registration for a new user."""
|
| 47 |
mock_mail.return_value = (True, "")
|
| 48 |
-
|
| 49 |
"new_user", "new@example.com", ""
|
| 50 |
)
|
| 51 |
-
assert
|
| 52 |
assert download is None
|
| 53 |
assert message == "mock_registration_submitted_md"
|
| 54 |
|
| 55 |
# Verify the user was added to the database
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
assert
|
| 59 |
|
| 60 |
-
def test_new_user_invalid_email(self,
|
| 61 |
"""Verify registration fails with an invalid email."""
|
| 62 |
-
|
| 63 |
"user", "invalid-email", ""
|
| 64 |
)
|
| 65 |
-
assert
|
| 66 |
assert download is None
|
| 67 |
assert message == "mock_invalid_email_md"
|
| 68 |
|
| 69 |
-
def test_new_user_federation_full(self,
|
| 70 |
"""Verify registration fails when the federation is full."""
|
| 71 |
-
mocker.patch("blossomtune_gradio.
|
| 72 |
-
|
| 73 |
"another_user", "another@example.com", ""
|
| 74 |
)
|
| 75 |
-
assert
|
| 76 |
assert download is None
|
| 77 |
assert message == "mock_federation_full_md"
|
| 78 |
|
| 79 |
-
def test_user_activation_success(self,
|
| 80 |
"""Verify a user can successfully activate their account."""
|
| 81 |
# Setup: Add a pending, non-activated user
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
)
|
| 87 |
-
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
"test_user", "test@example.com", "ABCDEF"
|
| 91 |
)
|
| 92 |
-
assert
|
| 93 |
assert download is None
|
| 94 |
assert message == "mock_activation_successful_md"
|
| 95 |
# Verify the user is now activated
|
| 96 |
-
|
| 97 |
-
|
| 98 |
)
|
| 99 |
-
assert
|
| 100 |
|
| 101 |
-
def test_user_activation_invalid_code(self,
|
| 102 |
"""Verify activation fails with an invalid code."""
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
)
|
| 109 |
-
|
|
|
|
| 110 |
|
| 111 |
-
|
| 112 |
"test_user", "test@example.com", "WRONGCODE"
|
| 113 |
)
|
| 114 |
-
assert
|
| 115 |
assert download is None
|
| 116 |
assert message == "mock_activation_invalid_md"
|
| 117 |
|
| 118 |
-
def test_status_check_approved(self,
|
| 119 |
"""Verify the status check for an approved user."""
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
)
|
| 134 |
-
in_memory_db.commit()
|
| 135 |
-
success, message, download = fed.check_participant_status(
|
| 136 |
"approved_user", "approved@example.com", "GHIJKL"
|
| 137 |
)
|
| 138 |
-
assert
|
| 139 |
assert download is not None
|
| 140 |
assert "mock_status_approved_md" in message
|
| 141 |
|
| 142 |
|
| 143 |
class TestManageRequest:
|
| 144 |
-
"""Test suite for the manage_request function."""
|
| 145 |
|
| 146 |
-
def test_approve_success(self,
|
| 147 |
"""Verify successful approval of a participant."""
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
"pending@example.com",
|
| 158 |
-
"CODE",
|
| 159 |
-
1,
|
| 160 |
-
),
|
| 161 |
-
)
|
| 162 |
-
in_memory_db.commit()
|
| 163 |
|
| 164 |
success, message = fed.manage_request("PENDING1", "10", "approve")
|
| 165 |
assert success is True
|
| 166 |
assert "is allowed to join" in message
|
| 167 |
|
| 168 |
# Verify status in DB
|
| 169 |
-
|
| 170 |
-
|
| 171 |
)
|
| 172 |
-
status
|
| 173 |
-
assert
|
| 174 |
-
assert partition_id == 10
|
| 175 |
|
| 176 |
-
def test_approve_not_activated(self,
|
| 177 |
"""Verify approval fails if the user is not activated."""
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
"pending2@example.com",
|
| 187 |
-
"CODE",
|
| 188 |
-
0,
|
| 189 |
-
),
|
| 190 |
-
)
|
| 191 |
-
in_memory_db.commit()
|
| 192 |
success, message = fed.manage_request("PENDING2", "11", "approve")
|
| 193 |
assert success is False
|
| 194 |
assert message == "mock_participant_not_activated_warning_md"
|
| 195 |
|
| 196 |
-
def test_deny_success(self,
|
| 197 |
"""Verify successful denial of a participant."""
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
"pending3@example.com",
|
| 207 |
-
"CODE",
|
| 208 |
-
1,
|
| 209 |
-
),
|
| 210 |
-
)
|
| 211 |
-
in_memory_db.commit()
|
| 212 |
success, message = fed.manage_request("PENDING3", "", "deny")
|
| 213 |
assert success is True
|
| 214 |
assert "is not allowed to join" in message
|
| 215 |
|
| 216 |
# Verify status in DB
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
| 219 |
|
| 220 |
|
| 221 |
-
def test_get_next_partition_id(
|
| 222 |
"""Verify the logic for finding the next available partition ID."""
|
| 223 |
-
cursor = in_memory_db.cursor()
|
| 224 |
# No approved users yet
|
| 225 |
assert fed.get_next_partion_id() == 0
|
| 226 |
|
| 227 |
-
# Add some approved users with assigned partitions
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
)
|
| 236 |
-
|
| 237 |
assert fed.get_next_partion_id() == 2
|
| 238 |
|
| 239 |
# Add another user, skipping a partition ID
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
)
|
| 244 |
-
|
| 245 |
assert fed.get_next_partion_id() == 2
|
|
|
|
| 1 |
import pytest
|
| 2 |
+
from datetime import datetime
|
| 3 |
|
| 4 |
from blossomtune_gradio import federation as fed
|
| 5 |
+
from blossomtune_gradio.database import Request
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
@pytest.fixture
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
class TestCheckParticipantStatus:
|
| 31 |
+
"""Test suite for the check_participant_status function using SQLAlchemy."""
|
| 32 |
|
| 33 |
+
def test_new_user_registration_success(self, db_session, mock_settings, mock_mail):
|
|
|
|
|
|
|
| 34 |
"""Verify successful registration for a new user."""
|
| 35 |
mock_mail.return_value = (True, "")
|
| 36 |
+
approved, message, download = fed.check_participant_status(
|
| 37 |
"new_user", "new@example.com", ""
|
| 38 |
)
|
| 39 |
+
assert approved is False
|
| 40 |
assert download is None
|
| 41 |
assert message == "mock_registration_submitted_md"
|
| 42 |
|
| 43 |
# Verify the user was added to the database
|
| 44 |
+
request = db_session.query(Request).filter_by(hf_handle="new_user").first()
|
| 45 |
+
assert request is not None
|
| 46 |
+
assert request.email == "new@example.com"
|
| 47 |
|
| 48 |
+
def test_new_user_invalid_email(self, db_session, mock_settings):
|
| 49 |
"""Verify registration fails with an invalid email."""
|
| 50 |
+
approved, message, download = fed.check_participant_status(
|
| 51 |
"user", "invalid-email", ""
|
| 52 |
)
|
| 53 |
+
assert approved is False
|
| 54 |
assert download is None
|
| 55 |
assert message == "mock_invalid_email_md"
|
| 56 |
|
| 57 |
+
def test_new_user_federation_full(self, db_session, mock_settings, mocker):
|
| 58 |
"""Verify registration fails when the federation is full."""
|
| 59 |
+
mocker.patch("blossomtune_gradio.config.MAX_NUM_NODES", 0)
|
| 60 |
+
approved, message, download = fed.check_participant_status(
|
| 61 |
"another_user", "another@example.com", ""
|
| 62 |
)
|
| 63 |
+
assert approved is False
|
| 64 |
assert download is None
|
| 65 |
assert message == "mock_federation_full_md"
|
| 66 |
|
| 67 |
+
def test_user_activation_success(self, db_session, mock_settings):
|
| 68 |
"""Verify a user can successfully activate their account."""
|
| 69 |
# Setup: Add a pending, non-activated user
|
| 70 |
+
pending_user = Request(
|
| 71 |
+
participant_id="PID123",
|
| 72 |
+
hf_handle="test_user",
|
| 73 |
+
email="test@example.com",
|
| 74 |
+
activation_code="ABCDEF",
|
| 75 |
+
is_activated=0,
|
| 76 |
)
|
| 77 |
+
db_session.add(pending_user)
|
| 78 |
+
db_session.commit()
|
| 79 |
|
| 80 |
+
approved, message, download = fed.check_participant_status(
|
| 81 |
"test_user", "test@example.com", "ABCDEF"
|
| 82 |
)
|
| 83 |
+
assert approved is False
|
| 84 |
assert download is None
|
| 85 |
assert message == "mock_activation_successful_md"
|
| 86 |
# Verify the user is now activated
|
| 87 |
+
activated_user = (
|
| 88 |
+
db_session.query(Request).filter_by(hf_handle="test_user").first()
|
| 89 |
)
|
| 90 |
+
assert activated_user.is_activated == 1
|
| 91 |
|
| 92 |
+
def test_user_activation_invalid_code(self, db_session, mock_settings):
|
| 93 |
"""Verify activation fails with an invalid code."""
|
| 94 |
+
pending_user = Request(
|
| 95 |
+
participant_id="PID123",
|
| 96 |
+
hf_handle="test_user",
|
| 97 |
+
email="test@example.com",
|
| 98 |
+
activation_code="ABCDEF",
|
| 99 |
+
is_activated=0,
|
| 100 |
)
|
| 101 |
+
db_session.add(pending_user)
|
| 102 |
+
db_session.commit()
|
| 103 |
|
| 104 |
+
approved, message, download = fed.check_participant_status(
|
| 105 |
"test_user", "test@example.com", "WRONGCODE"
|
| 106 |
)
|
| 107 |
+
assert approved is False
|
| 108 |
assert download is None
|
| 109 |
assert message == "mock_activation_invalid_md"
|
| 110 |
|
| 111 |
+
def test_status_check_approved(self, db_session, mock_settings):
|
| 112 |
"""Verify the status check for an approved user."""
|
| 113 |
+
approved_user = Request(
|
| 114 |
+
participant_id="PID456",
|
| 115 |
+
status="approved",
|
| 116 |
+
hf_handle="approved_user",
|
| 117 |
+
email="approved@example.com",
|
| 118 |
+
activation_code="GHIJKL",
|
| 119 |
+
is_activated=1,
|
| 120 |
+
partition_id=5,
|
| 121 |
+
)
|
| 122 |
+
db_session.add(approved_user)
|
| 123 |
+
db_session.commit()
|
| 124 |
+
|
| 125 |
+
approved, message, download = fed.check_participant_status(
|
|
|
|
|
|
|
|
|
|
| 126 |
"approved_user", "approved@example.com", "GHIJKL"
|
| 127 |
)
|
| 128 |
+
assert approved is True
|
| 129 |
assert download is not None
|
| 130 |
assert "mock_status_approved_md" in message
|
| 131 |
|
| 132 |
|
| 133 |
class TestManageRequest:
|
| 134 |
+
"""Test suite for the manage_request function using SQLAlchemy."""
|
| 135 |
|
| 136 |
+
def test_approve_success(self, db_session):
|
| 137 |
"""Verify successful approval of a participant."""
|
| 138 |
+
pending_user = Request(
|
| 139 |
+
participant_id="PENDING1",
|
| 140 |
+
status="pending",
|
| 141 |
+
hf_handle="pending_user",
|
| 142 |
+
email="pending@example.com",
|
| 143 |
+
is_activated=1,
|
| 144 |
+
)
|
| 145 |
+
db_session.add(pending_user)
|
| 146 |
+
db_session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
success, message = fed.manage_request("PENDING1", "10", "approve")
|
| 149 |
assert success is True
|
| 150 |
assert "is allowed to join" in message
|
| 151 |
|
| 152 |
# Verify status in DB
|
| 153 |
+
updated_user = (
|
| 154 |
+
db_session.query(Request).filter_by(participant_id="PENDING1").first()
|
| 155 |
)
|
| 156 |
+
assert updated_user.status == "approved"
|
| 157 |
+
assert updated_user.partition_id == 10
|
|
|
|
| 158 |
|
| 159 |
+
def test_approve_not_activated(self, db_session, mock_settings):
|
| 160 |
"""Verify approval fails if the user is not activated."""
|
| 161 |
+
pending_user = Request(
|
| 162 |
+
participant_id="PENDING2",
|
| 163 |
+
status="pending",
|
| 164 |
+
hf_handle="pending_user2",
|
| 165 |
+
is_activated=0,
|
| 166 |
+
)
|
| 167 |
+
db_session.add(pending_user)
|
| 168 |
+
db_session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
success, message = fed.manage_request("PENDING2", "11", "approve")
|
| 170 |
assert success is False
|
| 171 |
assert message == "mock_participant_not_activated_warning_md"
|
| 172 |
|
| 173 |
+
def test_deny_success(self, db_session):
|
| 174 |
"""Verify successful denial of a participant."""
|
| 175 |
+
pending_user = Request(
|
| 176 |
+
participant_id="PENDING3",
|
| 177 |
+
status="pending",
|
| 178 |
+
hf_handle="pending_user3",
|
| 179 |
+
is_activated=1,
|
| 180 |
+
)
|
| 181 |
+
db_session.add(pending_user)
|
| 182 |
+
db_session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
success, message = fed.manage_request("PENDING3", "", "deny")
|
| 184 |
assert success is True
|
| 185 |
assert "is not allowed to join" in message
|
| 186 |
|
| 187 |
# Verify status in DB
|
| 188 |
+
updated_user = (
|
| 189 |
+
db_session.query(Request).filter_by(participant_id="PENDING3").first()
|
| 190 |
+
)
|
| 191 |
+
assert updated_user.status == "denied"
|
| 192 |
|
| 193 |
|
| 194 |
+
def test_get_next_partition_id(db_session):
|
| 195 |
"""Verify the logic for finding the next available partition ID."""
|
|
|
|
| 196 |
# No approved users yet
|
| 197 |
assert fed.get_next_partion_id() == 0
|
| 198 |
|
| 199 |
+
# Add some approved users with assigned partitions
|
| 200 |
+
db_session.add(
|
| 201 |
+
Request(
|
| 202 |
+
participant_id="P1",
|
| 203 |
+
status="approved",
|
| 204 |
+
partition_id=0,
|
| 205 |
+
timestamp=datetime.utcnow(),
|
| 206 |
+
)
|
| 207 |
)
|
| 208 |
+
db_session.add(
|
| 209 |
+
Request(
|
| 210 |
+
participant_id="P2",
|
| 211 |
+
status="approved",
|
| 212 |
+
partition_id=1,
|
| 213 |
+
timestamp=datetime.utcnow(),
|
| 214 |
+
)
|
| 215 |
)
|
| 216 |
+
db_session.commit()
|
| 217 |
assert fed.get_next_partion_id() == 2
|
| 218 |
|
| 219 |
# Add another user, skipping a partition ID
|
| 220 |
+
db_session.add(
|
| 221 |
+
Request(
|
| 222 |
+
participant_id="P3",
|
| 223 |
+
status="approved",
|
| 224 |
+
partition_id=3,
|
| 225 |
+
timestamp=datetime.utcnow(),
|
| 226 |
+
)
|
| 227 |
)
|
| 228 |
+
db_session.commit()
|
| 229 |
assert fed.get_next_partion_id() == 2
|
tests/test_processing.py
CHANGED
|
@@ -2,6 +2,7 @@ import pytest
|
|
| 2 |
from unittest.mock import MagicMock, patch
|
| 3 |
|
| 4 |
from blossomtune_gradio import processing
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
@pytest.fixture(autouse=True)
|
|
@@ -12,8 +13,6 @@ def reset_process_store():
|
|
| 12 |
"""
|
| 13 |
processing.process_store = {"superlink": None, "runner": None}
|
| 14 |
yield
|
| 15 |
-
# Teardown is not strictly necessary as it's reset at the start,
|
| 16 |
-
# but it's good practice.
|
| 17 |
processing.process_store = {"superlink": None, "runner": None}
|
| 18 |
|
| 19 |
|
|
@@ -29,18 +28,15 @@ def test_start_superlink_success(mock_which, mock_thread):
|
|
| 29 |
assert success is True
|
| 30 |
assert message == "Superlink process started."
|
| 31 |
mock_thread.assert_called_once()
|
| 32 |
-
# Check that the thread is targeting the run_process function
|
| 33 |
call_args = mock_thread.call_args
|
| 34 |
assert call_args.kwargs["target"] == processing.run_process
|
| 35 |
-
# Check that the command is correct
|
| 36 |
assert call_args.kwargs["args"][0] == ["/fake/path/flower-superlink", "--insecure"]
|
| 37 |
|
| 38 |
|
| 39 |
def test_start_superlink_already_running(mocker):
|
| 40 |
"""Verify that start_superlink returns False if a process is already running."""
|
| 41 |
-
# Mock a running process
|
| 42 |
mock_process = MagicMock()
|
| 43 |
-
mock_process.poll.return_value = None
|
| 44 |
processing.process_store["superlink"] = mock_process
|
| 45 |
|
| 46 |
success, message = processing.start_superlink()
|
|
@@ -49,33 +45,60 @@ def test_start_superlink_already_running(mocker):
|
|
| 49 |
|
| 50 |
|
| 51 |
@patch("blossomtune_gradio.processing.os.path.exists", return_value=True)
|
| 52 |
-
@patch("blossomtune_gradio.processing.sqlite3.connect")
|
| 53 |
@patch("blossomtune_gradio.processing.threading.Thread")
|
| 54 |
@patch("blossomtune_gradio.processing.shutil.which", return_value="/fake/path/flwr")
|
| 55 |
-
def
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
mock_superlink = MagicMock()
|
| 59 |
mock_superlink.poll.return_value = None
|
| 60 |
processing.process_store["superlink"] = mock_superlink
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
success, message = processing.start_runner("app.main", "run1", "10")
|
| 63 |
|
|
|
|
| 64 |
assert success is True
|
| 65 |
assert message == "Federation Run is starting...."
|
| 66 |
mock_thread.assert_called_once()
|
| 67 |
-
mock_sqlite.assert_called_once() # Check if DB was updated
|
| 68 |
|
| 69 |
|
| 70 |
-
def
|
| 71 |
-
"""Verify start_runner fails if superlink is not running."""
|
| 72 |
processing.process_store["superlink"] = None
|
| 73 |
success, message = processing.start_runner("app.main", "run1", "10")
|
| 74 |
assert success is False
|
| 75 |
-
assert "Superlink is not running" in message
|
| 76 |
|
| 77 |
|
| 78 |
-
def test_start_runner_missing_args():
|
| 79 |
"""Verify start_runner fails if arguments are missing."""
|
| 80 |
mock_superlink = MagicMock()
|
| 81 |
mock_superlink.poll.return_value = None
|
|
@@ -87,7 +110,7 @@ def test_start_runner_missing_args():
|
|
| 87 |
|
| 88 |
|
| 89 |
@patch("blossomtune_gradio.processing.os.path.exists", return_value=False)
|
| 90 |
-
def test_start_runner_app_path_not_found(mock_exists,
|
| 91 |
"""Verify start_runner fails if the app path doesn't exist."""
|
| 92 |
mock_superlink = MagicMock()
|
| 93 |
mock_superlink.poll.return_value = None
|
|
@@ -101,7 +124,7 @@ def test_start_runner_app_path_not_found(mock_exists, in_memory_db):
|
|
| 101 |
def test_stop_process_running():
|
| 102 |
"""Verify stop_process terminates a running process."""
|
| 103 |
mock_process = MagicMock()
|
| 104 |
-
mock_process.poll.return_value = None
|
| 105 |
processing.process_store["superlink"] = mock_process
|
| 106 |
|
| 107 |
processing.stop_process("superlink")
|
|
@@ -118,7 +141,6 @@ def test_stop_process_not_running(mocker):
|
|
| 118 |
|
| 119 |
processing.stop_process("superlink")
|
| 120 |
|
| 121 |
-
# Check that the specific "no process was running" log was called
|
| 122 |
log_mock.assert_any_call(
|
| 123 |
"[Superlink] Stop command received, but no process was running."
|
| 124 |
)
|
|
|
|
| 2 |
from unittest.mock import MagicMock, patch
|
| 3 |
|
| 4 |
from blossomtune_gradio import processing
|
| 5 |
+
from blossomtune_gradio.database import Config
|
| 6 |
|
| 7 |
|
| 8 |
@pytest.fixture(autouse=True)
|
|
|
|
| 13 |
"""
|
| 14 |
processing.process_store = {"superlink": None, "runner": None}
|
| 15 |
yield
|
|
|
|
|
|
|
| 16 |
processing.process_store = {"superlink": None, "runner": None}
|
| 17 |
|
| 18 |
|
|
|
|
| 28 |
assert success is True
|
| 29 |
assert message == "Superlink process started."
|
| 30 |
mock_thread.assert_called_once()
|
|
|
|
| 31 |
call_args = mock_thread.call_args
|
| 32 |
assert call_args.kwargs["target"] == processing.run_process
|
|
|
|
| 33 |
assert call_args.kwargs["args"][0] == ["/fake/path/flower-superlink", "--insecure"]
|
| 34 |
|
| 35 |
|
| 36 |
def test_start_superlink_already_running(mocker):
|
| 37 |
"""Verify that start_superlink returns False if a process is already running."""
|
|
|
|
| 38 |
mock_process = MagicMock()
|
| 39 |
+
mock_process.poll.return_value = None
|
| 40 |
processing.process_store["superlink"] = mock_process
|
| 41 |
|
| 42 |
success, message = processing.start_superlink()
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
@patch("blossomtune_gradio.processing.os.path.exists", return_value=True)
|
|
|
|
| 48 |
@patch("blossomtune_gradio.processing.threading.Thread")
|
| 49 |
@patch("blossomtune_gradio.processing.shutil.which", return_value="/fake/path/flwr")
|
| 50 |
+
def test_start_runner_success_internal_superlink(
|
| 51 |
+
mock_which, mock_thread, mock_exists, db_session
|
| 52 |
+
):
|
| 53 |
+
"""Verify start_runner succeeds with an internal superlink and updates the DB."""
|
| 54 |
+
# Arrange: Mock a running internal superlink process
|
| 55 |
mock_superlink = MagicMock()
|
| 56 |
mock_superlink.poll.return_value = None
|
| 57 |
processing.process_store["superlink"] = mock_superlink
|
| 58 |
|
| 59 |
+
# Act
|
| 60 |
+
success, message = processing.start_runner("app.main", "run1", "15")
|
| 61 |
+
|
| 62 |
+
# Assert
|
| 63 |
+
assert success is True
|
| 64 |
+
assert message == "Federation Run is starting...."
|
| 65 |
+
mock_thread.assert_called_once()
|
| 66 |
+
|
| 67 |
+
# Verify DB was updated using SQLAlchemy
|
| 68 |
+
config_entry = db_session.query(Config).filter_by(key="num_partitions").first()
|
| 69 |
+
assert config_entry is not None
|
| 70 |
+
assert config_entry.value == "15"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@patch("blossomtune_gradio.processing.os.path.exists", return_value=True)
|
| 74 |
+
@patch("blossomtune_gradio.processing.threading.Thread")
|
| 75 |
+
@patch("blossomtune_gradio.processing.shutil.which", return_value="/fake/path/flwr")
|
| 76 |
+
def test_start_runner_success_external_superlink(
|
| 77 |
+
mock_which, mock_thread, mock_exists, db_session, mocker
|
| 78 |
+
):
|
| 79 |
+
"""Verify start_runner succeeds with an external superlink."""
|
| 80 |
+
# Arrange
|
| 81 |
+
mocker.patch("blossomtune_gradio.config.SUPERLINK_MODE", "external")
|
| 82 |
+
mocker.patch("blossomtune_gradio.util.is_port_open", return_value=True)
|
| 83 |
+
|
| 84 |
+
# Act
|
| 85 |
success, message = processing.start_runner("app.main", "run1", "10")
|
| 86 |
|
| 87 |
+
# Assert
|
| 88 |
assert success is True
|
| 89 |
assert message == "Federation Run is starting...."
|
| 90 |
mock_thread.assert_called_once()
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
+
def test_start_runner_internal_superlink_not_running(db_session):
|
| 94 |
+
"""Verify start_runner fails if internal superlink is not running."""
|
| 95 |
processing.process_store["superlink"] = None
|
| 96 |
success, message = processing.start_runner("app.main", "run1", "10")
|
| 97 |
assert success is False
|
| 98 |
+
assert "Internal Superlink is not running" in message
|
| 99 |
|
| 100 |
|
| 101 |
+
def test_start_runner_missing_args(db_session):
|
| 102 |
"""Verify start_runner fails if arguments are missing."""
|
| 103 |
mock_superlink = MagicMock()
|
| 104 |
mock_superlink.poll.return_value = None
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
@patch("blossomtune_gradio.processing.os.path.exists", return_value=False)
|
| 113 |
+
def test_start_runner_app_path_not_found(mock_exists, db_session):
|
| 114 |
"""Verify start_runner fails if the app path doesn't exist."""
|
| 115 |
mock_superlink = MagicMock()
|
| 116 |
mock_superlink.poll.return_value = None
|
|
|
|
| 124 |
def test_stop_process_running():
|
| 125 |
"""Verify stop_process terminates a running process."""
|
| 126 |
mock_process = MagicMock()
|
| 127 |
+
mock_process.poll.return_value = None
|
| 128 |
processing.process_store["superlink"] = mock_process
|
| 129 |
|
| 130 |
processing.stop_process("superlink")
|
|
|
|
| 141 |
|
| 142 |
processing.stop_process("superlink")
|
| 143 |
|
|
|
|
| 144 |
log_mock.assert_any_call(
|
| 145 |
"[Superlink] Stop command received, but no process was running."
|
| 146 |
)
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|