mrs83 commited on
Commit
1c0aad9
·
unverified ·
1 Parent(s): 2037a88

Migrate to SQLAlchemy and Alembic for database migrations. (#19)

Browse files

* Migrate to SQLAlchemy and Alembic for database migrations.
* remove db.init() call from gradio app
* run migrations at app startup
* fix run migrations
* Convert SQLAlchemy rows to simple lists
* add missing config key

alembic.ini ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A generic, single database configuration.
2
+
3
+ [alembic]
4
+
5
+ # database URL. This is consumed by the user-maintained env.py script only.
6
+ # other means of configuring database URLs may be customized within the env.py
7
+ # file.
8
+ sqlalchemy.url = driver://user:pass@localhost/dbname
9
+
10
+
11
+ # Logging configuration
12
+ [loggers]
13
+ keys = root,sqlalchemy,alembic
14
+
15
+ [handlers]
16
+ keys = console
17
+
18
+ [formatters]
19
+ keys = generic
20
+
21
+ [logger_root]
22
+ level = WARNING
23
+ handlers = console
24
+ qualname =
25
+
26
+ [logger_sqlalchemy]
27
+ level = WARNING
28
+ handlers =
29
+ qualname = sqlalchemy.engine
30
+
31
+ [logger_alembic]
32
+ level = INFO
33
+ handlers =
34
+ qualname = alembic
35
+
36
+ [handler_console]
37
+ class = StreamHandler
38
+ args = (sys.stderr,)
39
+ level = NOTSET
40
+ formatter = generic
41
+
42
+ [formatter_generic]
43
+ format = %(levelname)-5.5s [%(name)s] %(message)s
44
+ datefmt = %H:%M:%S
alembic/README ADDED
@@ -0,0 +1 @@
 
 
1
+ pyproject configuration, based on the generic configuration.
alembic/env.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging.config import fileConfig
2
+
3
+ from sqlalchemy import engine_from_config
4
+ from sqlalchemy import pool
5
+
6
+ from alembic import context
7
+ from blossomtune_gradio import config as cfg
8
+ from blossomtune_gradio.database import Base
9
+
10
+
11
+ # this is the Alembic Config object, which provides
12
+ # access to the values within the .ini file in use.
13
+ config = context.config
14
+ config.set_main_option("sqlalchemy.url", cfg.SQLALCHEMY_URL)
15
+
16
+ # Interpret the config file for Python logging.
17
+ # This line sets up loggers basically.
18
+ if config.config_file_name is not None:
19
+ fileConfig(config.config_file_name)
20
+
21
+ # add your model's MetaData object here
22
+ # for 'autogenerate' support
23
+ target_metadata = Base.metadata
24
+
25
+ # other values from the config, defined by the needs of env.py,
26
+ # can be acquired:
27
+ # my_important_option = config.get_main_option("my_important_option")
28
+ # ... etc.
29
+
30
+
31
+ def run_migrations_offline() -> None:
32
+ """Run migrations in 'offline' mode.
33
+
34
+ This configures the context with just a URL
35
+ and not an Engine, though an Engine is acceptable
36
+ here as well. By skipping the Engine creation
37
+ we don't even need a DBAPI to be available.
38
+
39
+ Calls to context.execute() here emit the given string to the
40
+ script output.
41
+
42
+ """
43
+ url = config.get_main_option("sqlalchemy.url")
44
+ context.configure(
45
+ url=url,
46
+ target_metadata=target_metadata,
47
+ literal_binds=True,
48
+ dialect_opts={"paramstyle": "named"},
49
+ )
50
+
51
+ with context.begin_transaction():
52
+ context.run_migrations()
53
+
54
+
55
+ def run_migrations_online() -> None:
56
+ """Run migrations in 'online' mode.
57
+
58
+ In this scenario we need to create an Engine
59
+ and associate a connection with the context.
60
+
61
+ """
62
+ connectable = engine_from_config(
63
+ config.get_section(config.config_ini_section, {}),
64
+ prefix="sqlalchemy.",
65
+ poolclass=pool.NullPool,
66
+ )
67
+
68
+ with connectable.connect() as connection:
69
+ context.configure(connection=connection, target_metadata=target_metadata)
70
+
71
+ with context.begin_transaction():
72
+ context.run_migrations()
73
+
74
+
75
+ if context.is_offline_mode():
76
+ run_migrations_offline()
77
+ else:
78
+ run_migrations_online()
alembic/script.py.mako ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """${message}
2
+
3
+ Revision ID: ${up_revision}
4
+ Revises: ${down_revision | comma,n}
5
+ Create Date: ${create_date}
6
+
7
+ """
8
+ from typing import Sequence, Union
9
+
10
+ from alembic import op
11
+ import sqlalchemy as sa
12
+ ${imports if imports else ""}
13
+
14
+ # revision identifiers, used by Alembic.
15
+ revision: str = ${repr(up_revision)}
16
+ down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
17
+ branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
18
+ depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
19
+
20
+
21
+ def upgrade() -> None:
22
+ """Upgrade schema."""
23
+ ${upgrades if upgrades else "pass"}
24
+
25
+
26
+ def downgrade() -> None:
27
+ """Downgrade schema."""
28
+ ${downgrades if downgrades else "pass"}
alembic/versions/0b01b44ab005_initial.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """initial
2
+
3
+ Revision ID: 0b01b44ab005
4
+ Revises:
5
+ Create Date: 2025-10-08 12:13:39.666969
6
+
7
+ """
8
+
9
+ from typing import Sequence, Union
10
+
11
+ from alembic import op
12
+ import sqlalchemy as sa
13
+
14
+
15
+ # revision identifiers, used by Alembic.
16
+ revision: str = "0b01b44ab005"
17
+ down_revision: Union[str, Sequence[str], None] = None
18
+ branch_labels: Union[str, Sequence[str], None] = None
19
+ depends_on: Union[str, Sequence[str], None] = None
20
+
21
+
22
+ def upgrade() -> None:
23
+ """Upgrade schema."""
24
+ # ### commands auto generated by Alembic - please adjust! ###
25
+ op.create_table(
26
+ "config",
27
+ sa.Column("key", sa.String(), nullable=False),
28
+ sa.Column("value", sa.String(), nullable=False),
29
+ sa.PrimaryKeyConstraint("key"),
30
+ )
31
+ op.create_table(
32
+ "requests",
33
+ sa.Column("participant_id", sa.String(), nullable=False),
34
+ sa.Column("status", sa.String(), nullable=False),
35
+ sa.Column(
36
+ "timestamp",
37
+ sa.DateTime(),
38
+ server_default=sa.text("(CURRENT_TIMESTAMP)"),
39
+ nullable=False,
40
+ ),
41
+ sa.Column("partition_id", sa.Integer(), nullable=True),
42
+ sa.Column("email", sa.String(), nullable=True),
43
+ sa.Column("hf_handle", sa.String(), nullable=True),
44
+ sa.Column("activation_code", sa.String(), nullable=True),
45
+ sa.Column("is_activated", sa.Integer(), nullable=False),
46
+ sa.PrimaryKeyConstraint("participant_id"),
47
+ )
48
+ # ### end Alembic commands ###
49
+
50
+
51
+ def downgrade() -> None:
52
+ """Downgrade schema."""
53
+ # ### commands auto generated by Alembic - please adjust! ###
54
+ op.drop_table("requests")
55
+ op.drop_table("config")
56
+ # ### end Alembic commands ###
app.py DELETED
@@ -1,5 +0,0 @@
1
- from blossomtune_gradio.gradio_app import demo
2
-
3
-
4
- if __name__ == "__main__":
5
- demo.launch()
 
 
 
 
 
 
blossomtune_gradio/__main__.py CHANGED
@@ -1,5 +1,9 @@
 
 
1
  from blossomtune_gradio.gradio_app import demo
2
 
3
 
4
  if __name__ == "__main__":
 
 
5
  demo.launch()
 
1
+ from blossomtune_gradio import config as cfg
2
+ from blossomtune_gradio import database as db
3
  from blossomtune_gradio.gradio_app import demo
4
 
5
 
6
  if __name__ == "__main__":
7
+ if cfg.RUN_MIGRATIONS_ON_STARTUP:
8
+ db.run_migrations()
9
  demo.launch()
blossomtune_gradio/config.py CHANGED
@@ -13,6 +13,7 @@ SPACE_OWNER = os.getenv("SPACE_OWNER", SPACE_ID.split("/")[0] if SPACE_ID else N
13
  DB_PATH = (
14
  "/data/db/federation.db" if os.path.isdir("/data/db") else "./data/db/federation.db"
15
  )
 
16
  MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
17
  SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
18
  SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
@@ -25,6 +26,9 @@ SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1:9092")
25
  SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
26
  SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
27
  SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
 
 
 
28
 
29
  # TLS root cert path. For production only.
30
  TLS_CERT_DIR = os.getenv("TLS_CERT_DIR", "./certs/")
 
13
  DB_PATH = (
14
  "/data/db/federation.db" if os.path.isdir("/data/db") else "./data/db/federation.db"
15
  )
16
+ SQLALCHEMY_URL = f"sqlite:///{os.path.abspath(DB_PATH)}"
17
  MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
18
  SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
19
  SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
 
26
  SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
27
  SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
28
  SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
29
+ RUN_MIGRATIONS_ON_STARTUP = util.strtobool(
30
+ os.getenv("RUN_MIGRATIONS_ON_STARTUP", "true")
31
+ ) # Set to false in prod.
32
 
33
  # TLS root cert path. For production only.
34
  TLS_CERT_DIR = os.getenv("TLS_CERT_DIR", "./certs/")
blossomtune_gradio/database.py CHANGED
@@ -1,34 +1,65 @@
1
- import sqlite3
 
 
 
 
2
  from blossomtune_gradio import config as cfg
3
 
4
 
5
- def init():
6
- """Initializes the DB, adds columns, and creates config table."""
7
- with sqlite3.connect(cfg.DB_PATH) as conn:
8
- cursor = conn.cursor()
9
- cursor.execute("""
10
- CREATE TABLE IF NOT EXISTS requests (
11
- participant_id TEXT PRIMARY KEY,
12
- status TEXT NOT NULL,
13
- timestamp TEXT NOT NULL,
14
- partition_id INTEGER,
15
- email TEXT,
16
- hf_handle TEXT,
17
- activation_code TEXT,
18
- is_activated INTEGER DEFAULT 0
19
- )
20
- """)
21
-
22
- # Create a simple key-value config table
23
- cursor.execute("""
24
- CREATE TABLE IF NOT EXISTS config (
25
- key TEXT PRIMARY KEY,
26
- value TEXT NOT NULL
27
- )
28
- """)
29
- # Initialize default number of partitions if not present
30
- cursor.execute(
31
- "INSERT OR IGNORE INTO config (key, value) VALUES (?, ?)",
32
- ("num_partitions", "10"),
33
  )
34
- conn.commit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from alembic import config
2
+ from sqlalchemy import create_engine, Column, String, Integer, DateTime, func
3
+ from sqlalchemy.orm import sessionmaker, declarative_base
4
+
5
+
6
  from blossomtune_gradio import config as cfg
7
 
8
 
9
+ Base = declarative_base()
10
+ engine = create_engine(cfg.SQLALCHEMY_URL)
11
+
12
+ # The sessionmaker factory generates new Session objects when called.
13
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
14
+
15
+
16
+ class Request(Base):
17
+ """
18
+ SQLAlchemy model for the 'requests' table.
19
+ This table stores information about participants wanting to join the federation.
20
+ """
21
+
22
+ __tablename__ = "requests"
23
+
24
+ participant_id = Column(String, primary_key=True)
25
+ status = Column(String, nullable=False, default="pending")
26
+ timestamp = Column(DateTime, nullable=False, server_default=func.now())
27
+ partition_id = Column(Integer, nullable=True)
28
+ email = Column(String, nullable=True)
29
+ hf_handle = Column(String, nullable=True)
30
+ activation_code = Column(String, nullable=True)
31
+ is_activated = Column(Integer, nullable=False, default=0)
32
+
33
+ def __repr__(self):
34
+ return (
35
+ f"<Request(participant_id='{self.participant_id}', status='{self.status}')>"
 
36
  )
37
+
38
+
39
+ class Config(Base):
40
+ """
41
+ SQLAlchemy model for the 'config' table.
42
+ A simple key-value store for application settings.
43
+ """
44
+
45
+ __tablename__ = "config"
46
+
47
+ key = Column(String, primary_key=True)
48
+ value = Column(String, nullable=False)
49
+
50
+ def __repr__(self):
51
+ return f"<Config(key='{self.key}', value='{self.value}')>"
52
+
53
+
54
+ def run_migrations():
55
+ """
56
+ Applies any pending Alembic migrations to the database.
57
+ This should be called on application startup.
58
+ """
59
+ print("Running database migrations...")
60
+ alembicArgs = [
61
+ "--raiseerr",
62
+ "upgrade",
63
+ "head",
64
+ ]
65
+ config.main(argv=alembicArgs)
blossomtune_gradio/federation.py CHANGED
@@ -1,13 +1,11 @@
1
  import string
2
  import secrets
3
- import sqlite3
4
-
5
- from datetime import datetime
6
 
7
  from blossomtune_gradio import config as cfg
8
  from blossomtune_gradio import mail
9
  from blossomtune_gradio import util
10
  from blossomtune_gradio.settings import settings
 
11
 
12
 
13
  def generate_participant_id(length=6):
@@ -23,105 +21,99 @@ def generate_activation_code(length=8):
23
 
24
 
25
  def check_participant_status(pid_to_check: str, email: str, activation_code: str):
26
- """Handles a participant's request to join, activate, or check status."""
27
- with sqlite3.connect(cfg.DB_PATH) as conn:
28
- cursor = conn.cursor()
 
 
 
 
 
 
29
  if activation_code:
30
- cursor.execute(
31
- "SELECT participant_id, status, partition_id, is_activated, activation_code FROM requests WHERE hf_handle = ? AND email = ? AND activation_code = ?",
32
- (pid_to_check, email, activation_code),
33
- )
34
- else:
35
- cursor.execute(
36
- "SELECT participant_id, status, partition_id, is_activated, activation_code FROM requests WHERE hf_handle = ? AND email = ?",
37
- (pid_to_check, email),
38
- )
39
- result = cursor.fetchone()
40
 
41
- cursor.execute("SELECT value FROM config WHERE key = 'num_partitions'")
42
- num_partitions_res = cursor.fetchone()
43
- num_partitions = num_partitions_res[0] if num_partitions_res else "10"
44
 
45
- # Case 1: New user registration
46
- if result is None:
47
- if activation_code:
48
- return (False, settings.get_text("activation_invalid_md"), None)
49
- if not util.validate_email(email):
50
- return (False, settings.get_text("invalid_email_md"), None)
51
-
52
- with sqlite3.connect(cfg.DB_PATH) as conn:
53
- cursor = conn.cursor()
54
- cursor.execute("SELECT COUNT(*) FROM requests WHERE status = 'approved'")
55
- approved_count = cursor.fetchone()[0]
 
 
 
 
56
  if approved_count >= cfg.MAX_NUM_NODES:
57
  return (False, settings.get_text("federation_full_md"), None)
58
 
59
- participant_id = generate_participant_id()
60
- new_activation_code = generate_activation_code()
61
- mail_sent, message = mail.send_activation_email(email, new_activation_code)
62
- if mail_sent:
63
- with sqlite3.connect(cfg.DB_PATH) as conn:
64
- conn.execute(
65
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
66
- (
67
- participant_id,
68
- "pending",
69
- datetime.utcnow().isoformat(),
70
- pid_to_check,
71
- email,
72
- new_activation_code,
73
- 0,
74
- ),
75
- )
76
- return (True, settings.get_text("registration_submitted_md"), None)
77
- else:
78
- return (False, message, None)
79
-
80
- # Existing user
81
- participant_id, status, partition_id, is_activated, stored_code = result
82
-
83
- # Case 2: User is activating their account
84
- if not is_activated:
85
- if activation_code == stored_code:
86
- with sqlite3.connect(cfg.DB_PATH) as conn:
87
- conn.execute(
88
- "UPDATE requests SET is_activated = 1 WHERE hf_handle = ?",
89
- (pid_to_check,),
90
  )
91
- return (True, settings.get_text("activation_successful_md"), None)
92
- else:
93
- return (False, settings.get_text("activation_invalid_md"), None)
94
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  if not activation_code:
96
- return (False, settings.get_text("missing_activation_code_md"))
97
-
98
- # Case 3: Activated user is checking their status
99
- if status == "approved":
100
- hostname = (
101
- "localhost"
102
- if not cfg.SPACE_ID
103
- else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
104
- )
105
- superlink_hostname = cfg.SUPERLINK_HOST or hostname
106
-
107
- connection_string = settings.get_text(
108
- "status_approved_md",
109
- participant_id=participant_id,
110
- partition_id=partition_id,
111
- superlink_hostname=superlink_hostname,
112
- superlink_port=cfg.SUPERLINK_PORT,
113
- num_partitions=num_partitions,
114
- )
115
- # TODO: build and provide .blossomfile for download
116
- return (True, connection_string, cfg.BLOSSOMTUNE_TLS_CERT_PATH)
117
- elif status == "pending":
118
- return (False, settings.get_text("status_pending_md"))
119
- else: # Denied
120
- return (
121
- False,
122
- settings.get_text("status_denied_md", participant_id=participant_id),
123
- None,
124
- )
 
125
 
126
 
127
  def manage_request(participant_id: str, partition_id: str, action: str):
@@ -129,29 +121,31 @@ def manage_request(participant_id: str, partition_id: str, action: str):
129
  if not participant_id:
130
  return False, "Please select a participant from the pending requests table."
131
 
132
- if action == "approve":
133
- if not partition_id or not partition_id.isdigit():
134
- return False, "Please provide a valid integer for the Partition ID."
 
 
 
135
 
136
- p_id_int = int(partition_id)
137
- with sqlite3.connect(cfg.DB_PATH) as conn:
138
- cursor = conn.cursor()
139
- cursor.execute(
140
- "SELECT is_activated FROM requests WHERE participant_id = ?",
141
- (participant_id,),
142
- )
143
- is_activated_res = cursor.fetchone()
144
- if not is_activated_res or not is_activated_res[0]:
145
  return (
146
  False,
147
  settings.get_text("participant_not_activated_warning_md"),
148
  )
149
 
150
- cursor.execute(
151
- "SELECT 1 FROM requests WHERE partition_id = ? AND status = 'approved'",
152
- (p_id_int,),
 
153
  )
154
- if cursor.fetchone():
 
155
  return (
156
  False,
157
  settings.get_text(
@@ -159,20 +153,17 @@ def manage_request(participant_id: str, partition_id: str, action: str):
159
  ),
160
  )
161
 
162
- conn.execute(
163
- "UPDATE requests SET status = ?, partition_id = ? WHERE participant_id = ?",
164
- ("approved", p_id_int, participant_id),
165
- )
166
  return (
167
  True,
168
  f"Participant {participant_id} is allowed to join the federation.",
169
  )
170
- else: # Deny
171
- with sqlite3.connect(cfg.DB_PATH) as conn:
172
- conn.execute(
173
- "UPDATE requests SET status = ?, partition_id = NULL WHERE participant_id = ?",
174
- ("denied", participant_id),
175
- )
176
  return (
177
  True,
178
  f"Participant {participant_id} is not allowed to join the federation.",
@@ -180,12 +171,14 @@ def manage_request(participant_id: str, partition_id: str, action: str):
180
 
181
 
182
  def get_next_partion_id() -> int:
183
- with sqlite3.connect(cfg.DB_PATH) as conn:
184
- cursor = conn.cursor()
185
- cursor.execute(
186
- "SELECT partition_id FROM requests WHERE status = 'approved' AND partition_id IS NOT NULL"
 
 
187
  )
188
- used_ids = {row[0] for row in cursor.fetchall()}
189
 
190
  next_id = 0
191
  while next_id in used_ids:
 
1
  import string
2
  import secrets
 
 
 
3
 
4
  from blossomtune_gradio import config as cfg
5
  from blossomtune_gradio import mail
6
  from blossomtune_gradio import util
7
  from blossomtune_gradio.settings import settings
8
+ from blossomtune_gradio.database import SessionLocal, Request, Config
9
 
10
 
11
  def generate_participant_id(length=6):
 
21
 
22
 
23
  def check_participant_status(pid_to_check: str, email: str, activation_code: str):
24
+ """
25
+ Handles a participant's request to join, activate, or check status using SQLAlchemy.
26
+ Returns a tuple: (is_approved: bool, message: str, data: any | None)
27
+ The 'is_approved' boolean is True ONLY when the participant's final status is 'approved'.
28
+ """
29
+ with SessionLocal() as db:
30
+ query = db.query(Request).filter(
31
+ Request.hf_handle == pid_to_check, Request.email == email
32
+ )
33
  if activation_code:
34
+ query = query.filter(Request.activation_code == activation_code)
 
 
 
 
 
 
 
 
 
35
 
36
+ request = query.first()
 
 
37
 
38
+ num_partitions_config = (
39
+ db.query(Config).filter(Config.key == "num_partitions").first()
40
+ )
41
+ num_partitions = num_partitions_config.value if num_partitions_config else "10"
42
+
43
+ # Case 1: New user registration
44
+ if request is None:
45
+ if activation_code:
46
+ return (False, settings.get_text("activation_invalid_md"), None)
47
+ if not util.validate_email(email):
48
+ return (False, settings.get_text("invalid_email_md"), None)
49
+
50
+ approved_count = (
51
+ db.query(Request).filter(Request.status == "approved").count()
52
+ )
53
  if approved_count >= cfg.MAX_NUM_NODES:
54
  return (False, settings.get_text("federation_full_md"), None)
55
 
56
+ participant_id = generate_participant_id()
57
+ new_activation_code = generate_activation_code()
58
+ mail_sent, message = mail.send_activation_email(email, new_activation_code)
59
+
60
+ if mail_sent:
61
+ new_request = Request(
62
+ participant_id=participant_id,
63
+ hf_handle=pid_to_check,
64
+ email=email,
65
+ activation_code=new_activation_code,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
+ db.add(new_request)
68
+ db.commit()
69
+ # A successful registration step, but not yet approved for federation.
70
+ return (False, settings.get_text("registration_submitted_md"), None)
71
+ else:
72
+ return (False, message, None)
73
+
74
+ # Case 2: User is activating their account
75
+ if not request.is_activated:
76
+ if activation_code == request.activation_code:
77
+ request.is_activated = 1
78
+ db.commit()
79
+ # A successful activation step, but not yet approved.
80
+ return (False, settings.get_text("activation_successful_md"), None)
81
+ else:
82
+ return (False, settings.get_text("activation_invalid_md"), None)
83
+
84
+ # At this point, user is activated.
85
+ # They must provide the activation code to check their final status.
86
  if not activation_code:
87
+ return (False, settings.get_text("missing_activation_code_md"), None)
88
+
89
+ # Case 3: Activated user is checking their final status
90
+ if request.status == "approved":
91
+ hostname = (
92
+ "localhost"
93
+ if not cfg.SPACE_ID
94
+ else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
95
+ )
96
+ superlink_hostname = cfg.SUPERLINK_HOST or hostname
97
+
98
+ connection_string = settings.get_text(
99
+ "status_approved_md",
100
+ participant_id=request.participant_id,
101
+ partition_id=request.partition_id,
102
+ superlink_hostname=superlink_hostname,
103
+ num_partitions=num_partitions,
104
+ )
105
+ # The user is fully approved. Return success and the cert path.
106
+ return (True, connection_string, cfg.BLOSSOMTUNE_TLS_CERT_PATH)
107
+ elif request.status == "pending":
108
+ return (False, settings.get_text("status_pending_md"), None)
109
+ else: # Denied
110
+ return (
111
+ False,
112
+ settings.get_text(
113
+ "status_denied_md", participant_id=request.participant_id
114
+ ),
115
+ None,
116
+ )
117
 
118
 
119
  def manage_request(participant_id: str, partition_id: str, action: str):
 
121
  if not participant_id:
122
  return False, "Please select a participant from the pending requests table."
123
 
124
+ with SessionLocal() as db:
125
+ request = (
126
+ db.query(Request).filter(Request.participant_id == participant_id).first()
127
+ )
128
+ if not request:
129
+ return False, "Participant not found."
130
 
131
+ if action == "approve":
132
+ if not partition_id or not partition_id.isdigit():
133
+ return False, "Please provide a valid integer for the Partition ID."
134
+
135
+ p_id_int = int(partition_id)
136
+ if not request.is_activated:
 
 
 
137
  return (
138
  False,
139
  settings.get_text("participant_not_activated_warning_md"),
140
  )
141
 
142
+ existing_participant = (
143
+ db.query(Request)
144
+ .filter(Request.partition_id == p_id_int, Request.status == "approved")
145
+ .first()
146
  )
147
+
148
+ if existing_participant:
149
  return (
150
  False,
151
  settings.get_text(
 
153
  ),
154
  )
155
 
156
+ request.status = "approved"
157
+ request.partition_id = p_id_int
158
+ db.commit()
 
159
  return (
160
  True,
161
  f"Participant {participant_id} is allowed to join the federation.",
162
  )
163
+ else: # Deny
164
+ request.status = "denied"
165
+ request.partition_id = None
166
+ db.commit()
 
 
167
  return (
168
  True,
169
  f"Participant {participant_id} is not allowed to join the federation.",
 
171
 
172
 
173
  def get_next_partion_id() -> int:
174
+ """Finds the lowest available partition ID."""
175
+ with SessionLocal() as db:
176
+ used_ids_query = (
177
+ db.query(Request.partition_id)
178
+ .filter(Request.status == "approved", Request.partition_id.isnot(None))
179
+ .all()
180
  )
181
+ used_ids = {row[0] for row in used_ids_query}
182
 
183
  next_id = 0
184
  while next_id in used_ids:
blossomtune_gradio/gradio_app.py CHANGED
@@ -1,12 +1,10 @@
1
  import gradio as gr
2
 
3
- from blossomtune_gradio import database as db
4
  from blossomtune_gradio.ui import components
5
  from blossomtune_gradio.ui import callbacks
6
 
7
 
8
  with gr.Blocks(theme=gr.themes.Soft(), title="Flower Superlink & Runner") as demo:
9
- db.init()
10
  gr.Markdown("BlossomTune 🌸 Flower Superlink & Runner")
11
  with gr.Row():
12
  login_button = gr.LoginButton()
 
1
  import gradio as gr
2
 
 
3
  from blossomtune_gradio.ui import components
4
  from blossomtune_gradio.ui import callbacks
5
 
6
 
7
  with gr.Blocks(theme=gr.themes.Soft(), title="Flower Superlink & Runner") as demo:
 
8
  gr.Markdown("BlossomTune 🌸 Flower Superlink & Runner")
9
  with gr.Row():
10
  login_button = gr.LoginButton()
blossomtune_gradio/processing.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  import shutil
3
- import sqlite3
4
  import threading
5
  import subprocess
6
 
7
  from blossomtune_gradio.logs import log
8
  from blossomtune_gradio import config as cfg
 
 
9
 
10
 
11
  # In-memory store for background processes and logs
@@ -44,7 +45,9 @@ def start_superlink():
44
 
45
  if process_store["superlink"] and process_store["superlink"].poll() is None:
46
  return False, "Superlink process is already running."
47
- command = [shutil.which("flower-superlink"), "--insecure"]
 
 
48
  threading.Thread(
49
  target=run_process, args=(command, "superlink"), daemon=True
50
  ).start()
@@ -58,31 +61,36 @@ def start_runner(
58
  ):
59
  if process_store["runner"] and process_store["runner"].poll() is None:
60
  return False, "A Runner process is already running."
61
- if (
62
- not (process_store["superlink"] and process_store["superlink"].poll() is None)
63
- and not cfg.SUPERLINK_MODE == "external"
64
- ):
 
 
65
  return (
66
  False,
67
- "Superlink is not running. Please start it before starting the runner.",
68
  )
69
- # TODO: check if external superlink is running
70
  if not all([runner_app, run_id, num_partitions]):
71
  return False, "Please provide a Runner App, Run ID, and Total Partitions."
72
  if not num_partitions.isdigit() or int(num_partitions) <= 0:
73
  return False, "Total Partitions must be a positive integer."
74
 
75
- with sqlite3.connect(cfg.DB_PATH) as conn:
76
- conn.execute(
77
- "UPDATE config SET value = ? WHERE key = 'num_partitions'",
78
- (num_partitions,),
79
- )
80
- conn.commit()
 
 
81
 
82
  runner_app_path = runner_app.replace(".", os.path.sep)
83
  if not os.path.exists(runner_app_path):
84
  return False, f"Unable to find app path '{runner_app_path}'."
85
 
 
86
  command = [
87
  shutil.which("flwr"),
88
  "run",
 
1
  import os
2
  import shutil
 
3
  import threading
4
  import subprocess
5
 
6
  from blossomtune_gradio.logs import log
7
  from blossomtune_gradio import config as cfg
8
+ from blossomtune_gradio import util
9
+ from blossomtune_gradio.database import SessionLocal, Config
10
 
11
 
12
  # In-memory store for background processes and logs
 
45
 
46
  if process_store["superlink"] and process_store["superlink"].poll() is None:
47
  return False, "Superlink process is already running."
48
+
49
+ # The command needs to be adapted for TLS if it's not insecure
50
+ command = [shutil.which("flower-superlink"), "--insecure"] # Placeholder
51
  threading.Thread(
52
  target=run_process, args=(command, "superlink"), daemon=True
53
  ).start()
 
61
  ):
62
  if process_store["runner"] and process_store["runner"].poll() is None:
63
  return False, "A Runner process is already running."
64
+
65
+ # Check if the Superlink is running, respecting the configured mode
66
+ if cfg.SUPERLINK_MODE == "external":
67
+ if not util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT):
68
+ return False, "External Superlink is not running or unreachable."
69
+ elif not (process_store["superlink"] and process_store["superlink"].poll() is None):
70
  return (
71
  False,
72
+ "Internal Superlink is not running. Please start it before starting the runner.",
73
  )
74
+
75
  if not all([runner_app, run_id, num_partitions]):
76
  return False, "Please provide a Runner App, Run ID, and Total Partitions."
77
  if not num_partitions.isdigit() or int(num_partitions) <= 0:
78
  return False, "Total Partitions must be a positive integer."
79
 
80
+ # Update the number of partitions in the database using SQLAlchemy
81
+ with SessionLocal() as db:
82
+ config_entry = db.query(Config).filter(Config.key == "num_partitions").first()
83
+ if config_entry:
84
+ config_entry.value = num_partitions
85
+ else:
86
+ db.add(Config(key="num_partitions", value=num_partitions))
87
+ db.commit()
88
 
89
  runner_app_path = runner_app.replace(".", os.path.sep)
90
  if not os.path.exists(runner_app_path):
91
  return False, f"Unable to find app path '{runner_app_path}'."
92
 
93
+ # Construct the command for a TLS-enabled runner
94
  command = [
95
  shutil.which("flwr"),
96
  "run",
blossomtune_gradio/ui/callbacks.py CHANGED
@@ -1,5 +1,4 @@
1
  import time
2
- import sqlite3
3
  import gradio as gr
4
  import pandas as pd
5
 
@@ -9,6 +8,7 @@ from blossomtune_gradio import federation as fed
9
  from blossomtune_gradio import processing
10
  from blossomtune_gradio.settings import settings
11
  from blossomtune_gradio import util
 
12
 
13
  from . import components
14
  from . import auth
@@ -46,16 +46,31 @@ def get_full_status_update(
46
  else:
47
  auth_status = settings.get_text("auth_status_local_mode_md")
48
 
49
- with sqlite3.connect(cfg.DB_PATH) as conn:
50
- pending_rows = conn.execute(
51
- "SELECT participant_id, hf_handle, email FROM requests WHERE status = 'pending' AND is_activated = 1 ORDER BY timestamp ASC"
52
- ).fetchall()
53
- approved_rows = conn.execute(
54
- "SELECT participant_id, hf_handle, email, partition_id FROM requests WHERE status = 'approved' ORDER BY timestamp DESC"
55
- ).fetchall()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Superlink Status Logic
58
- superlink_btn_update = gr.update() # Default empty update
59
 
60
  if cfg.SUPERLINK_MODE == "internal":
61
  superlink_is_running = (
@@ -78,7 +93,6 @@ def get_full_status_update(
78
  else:
79
  is_open = util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT)
80
  superlink_status = "🟢 Running" if is_open else "🔴 Not Running"
81
- # Disable the button in external mode
82
  superlink_btn_update = gr.update(value="Managed Externally", interactive=False)
83
  else:
84
  superlink_status = "⚠️ Invalid Mode"
@@ -126,7 +140,6 @@ def toggle_superlink(
126
  ):
127
  """Toggles the Superlink process on or off."""
128
  if not auth.is_space_owner(profile, oauth_token):
129
- # Hardcode warning text as it's not in the schema
130
  gr.Warning("You are not authorized to perform this operation.")
131
  return
132
  if (
@@ -147,7 +160,6 @@ def toggle_runner(
147
  ):
148
  """Toggles the Runner process on or off."""
149
  if not auth.is_space_owner(profile, oauth_token):
150
- # Hardcode warning text as it's not in the schema
151
  gr.Warning("You are not authorized to perform this operation.")
152
  return
153
  if (
@@ -204,8 +216,8 @@ def on_check_participant_status(
204
  pid_to_check = user_hf_handle.strip()
205
  email_to_add = email.strip()
206
  activation_code_to_check = activation_code.strip()
207
- # The federation module is responsible for getting the correct text from settings
208
- _, message, download = fed.check_participant_status(
209
  pid_to_check, email_to_add, activation_code_to_check
210
  )
211
  return {
@@ -217,7 +229,6 @@ def on_check_participant_status(
217
 
218
 
219
  def on_manage_fed_request(participant_id: str, partition_id: str, action: str):
220
- # The federation module is responsible for getting the correct text from settings
221
  result, message = fed.manage_request(participant_id, partition_id, action)
222
  if result:
223
  gr.Info(message)
 
1
  import time
 
2
  import gradio as gr
3
  import pandas as pd
4
 
 
8
  from blossomtune_gradio import processing
9
  from blossomtune_gradio.settings import settings
10
  from blossomtune_gradio import util
11
+ from blossomtune_gradio.database import SessionLocal, Request
12
 
13
  from . import components
14
  from . import auth
 
46
  else:
47
  auth_status = settings.get_text("auth_status_local_mode_md")
48
 
49
+ with SessionLocal() as db:
50
+ pending_results = (
51
+ db.query(Request.participant_id, Request.hf_handle, Request.email)
52
+ .filter(Request.status == "pending", Request.is_activated == 1)
53
+ .order_by(Request.timestamp.asc())
54
+ .all()
55
+ )
56
+ approved_results = (
57
+ db.query(
58
+ Request.participant_id,
59
+ Request.hf_handle,
60
+ Request.email,
61
+ Request.partition_id,
62
+ )
63
+ .filter(Request.status == "approved")
64
+ .order_by(Request.timestamp.desc())
65
+ .all()
66
+ )
67
+
68
+ # Convert SQLAlchemy rows to simple lists
69
+ pending_rows = [list(row) for row in pending_results]
70
+ approved_rows = [list(row) for row in approved_results]
71
 
72
  # Superlink Status Logic
73
+ superlink_btn_update = gr.update()
74
 
75
  if cfg.SUPERLINK_MODE == "internal":
76
  superlink_is_running = (
 
93
  else:
94
  is_open = util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT)
95
  superlink_status = "🟢 Running" if is_open else "🔴 Not Running"
 
96
  superlink_btn_update = gr.update(value="Managed Externally", interactive=False)
97
  else:
98
  superlink_status = "⚠️ Invalid Mode"
 
140
  ):
141
  """Toggles the Superlink process on or off."""
142
  if not auth.is_space_owner(profile, oauth_token):
 
143
  gr.Warning("You are not authorized to perform this operation.")
144
  return
145
  if (
 
160
  ):
161
  """Toggles the Runner process on or off."""
162
  if not auth.is_space_owner(profile, oauth_token):
 
163
  gr.Warning("You are not authorized to perform this operation.")
164
  return
165
  if (
 
216
  pid_to_check = user_hf_handle.strip()
217
  email_to_add = email.strip()
218
  activation_code_to_check = activation_code.strip()
219
+
220
+ approved, message, download = fed.check_participant_status(
221
  pid_to_check, email_to_add, activation_code_to_check
222
  )
223
  return {
 
229
 
230
 
231
  def on_manage_fed_request(participant_id: str, partition_id: str, action: str):
 
232
  result, message = fed.manage_request(participant_id, partition_id, action)
233
  if result:
234
  gr.Info(message)
flower_apps/quickstart_huggingface/pyproject.toml CHANGED
@@ -60,3 +60,4 @@ options.backend.client-resources.num-gpus = 0.0 # at most 4 ClientApp will run i
60
  [tool.flwr.federations.local-deployment]
61
  address = "0.0.0.0:9093"
62
  insecure = true
 
 
60
  [tool.flwr.federations.local-deployment]
61
  address = "0.0.0.0:9093"
62
  insecure = true
63
+ root-certificate = ""
pyproject.toml CHANGED
@@ -16,6 +16,8 @@ dependencies = [
16
  "markupsafe==2.1.3",
17
  "jinja2>=3.1.6",
18
  "mlx[cpu]>=0.29.2",
 
 
19
  ]
20
 
21
  [tool.uv.sources]
@@ -90,3 +92,87 @@ build-backend = "setuptools.build_meta"
90
  where = ["."]
91
  include = ["blossomtune_gradio", "flower_apps"]
92
  exclude = ["results", "data"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  "markupsafe==2.1.3",
17
  "jinja2>=3.1.6",
18
  "mlx[cpu]>=0.29.2",
19
+ "alembic>=1.16.5",
20
+ "sqlalchemy>=2.0.43",
21
  ]
22
 
23
  [tool.uv.sources]
 
92
  where = ["."]
93
  include = ["blossomtune_gradio", "flower_apps"]
94
  exclude = ["results", "data"]
95
+
96
+
97
+ [tool.alembic]
98
+
99
+ # path to migration scripts.
100
+ # this is typically a path given in POSIX (e.g. forward slashes)
101
+ # format, relative to the token %(here)s which refers to the location of this
102
+ # ini file
103
+ script_location = "%(here)s/alembic"
104
+
105
+ # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
106
+ # Uncomment the line below if you want the files to be prepended with date and time
107
+ # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
108
+ # for all available tokens
109
+ # file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s"
110
+
111
+ # additional paths to be prepended to sys.path. defaults to the current working directory.
112
+ prepend_sys_path = [
113
+ "."
114
+ ]
115
+
116
+ # timezone to use when rendering the date within the migration file
117
+ # as well as the filename.
118
+ # If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
119
+ # Any required deps can installed by adding `alembic[tz]` to the pip requirements
120
+ # string value is passed to ZoneInfo()
121
+ # leave blank for localtime
122
+ # timezone =
123
+
124
+ # max length of characters to apply to the "slug" field
125
+ # truncate_slug_length = 40
126
+
127
+ # set to 'true' to run the environment during
128
+ # the 'revision' command, regardless of autogenerate
129
+ # revision_environment = false
130
+
131
+ # set to 'true' to allow .pyc and .pyo files without
132
+ # a source .py file to be detected as revisions in the
133
+ # versions/ directory
134
+ # sourceless = false
135
+
136
+ # version location specification; This defaults
137
+ # to <script_location>/versions. When using multiple version
138
+ # directories, initial revisions must be specified with --version-path.
139
+ # version_locations = [
140
+ # "%(here)s/alembic/versions",
141
+ # "%(here)s/foo/bar"
142
+ # ]
143
+
144
+
145
+ # set to 'true' to search source files recursively
146
+ # in each "version_locations" directory
147
+ # new in Alembic version 1.10
148
+ # recursive_version_locations = false
149
+
150
+ # the output encoding used when revision files
151
+ # are written from script.py.mako
152
+ # output_encoding = "utf-8"
153
+
154
+ # This section defines scripts or Python functions that are run
155
+ # on newly generated revision scripts. See the documentation for further
156
+ # detail and examples
157
+ # [[tool.alembic.post_write_hooks]]
158
+ # format using "black" - use the console_scripts runner,
159
+ # against the "black" entrypoint
160
+ # name = "black"
161
+ # type = "console_scripts"
162
+ # entrypoint = "black"
163
+ # options = "-l 79 REVISION_SCRIPT_FILENAME"
164
+ #
165
+ # [[tool.alembic.post_write_hooks]]
166
+ # lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
167
+ # name = "ruff"
168
+ # type = "module"
169
+ # module = "ruff"
170
+ # options = "check --fix REVISION_SCRIPT_FILENAME"
171
+ #
172
+ # [[tool.alembic.post_write_hooks]]
173
+ # Alternatively, use the exec runner to execute a binary found on your PATH
174
+ # name = "ruff"
175
+ # type = "exec"
176
+ # executable = "ruff"
177
+ # options = "check --fix REVISION_SCRIPT_FILENAME"
178
+
tests/conftest.py CHANGED
@@ -1,18 +1,60 @@
1
  import pytest
2
- import sqlite3
3
 
4
- from blossomtune_gradio import database
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  @pytest.fixture
8
- def in_memory_db(mocker):
9
  """
10
- Fixture to set up and tear down an in-memory SQLite database for tests.
11
- It ensures that the same connection object is used for both schema
12
- initialization and the test execution.
13
  """
14
- con = sqlite3.connect(":memory:")
15
- mocker.patch("sqlite3.connect", return_value=con)
16
- database.init()
17
- yield con
18
- con.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pytest
2
+ from unittest.mock import MagicMock
3
 
4
+ from alembic.config import Config
5
+ from alembic import command
6
+ from sqlalchemy import create_engine
7
+ from sqlalchemy.orm import sessionmaker
8
+
9
+ from blossomtune_gradio import config
10
+
11
+
12
+ @pytest.fixture(scope="session")
13
+ def alembic_config():
14
+ """Fixture to create a valid Alembic Config object."""
15
+ return Config("alembic.ini")
16
 
17
 
18
  @pytest.fixture
19
+ def db_session(mocker, tmp_path):
20
  """
21
+ Fixture to set up a clean file-based SQLite database for each test function.
22
+ It creates the database schema using Alembic programmatically and ensures all
23
+ modules use the same test database session.
24
  """
25
+ db_file = tmp_path / "test_federation.db"
26
+ db_url = f"sqlite:///{db_file}"
27
+ mocker.patch.object(config, "SQLALCHEMY_URL", db_url)
28
+
29
+ # Create an Alembic Config object and point it to the temp database.
30
+ alembic_cfg = Config()
31
+ alembic_cfg.set_main_option("script_location", "alembic")
32
+ alembic_cfg.set_main_option("sqlalchemy.url", db_url)
33
+
34
+ # Apply the migrations to create the schema in the temporary database.
35
+ command.upgrade(alembic_cfg, "head")
36
+
37
+ # Set up the SQLAlchemy engine and session factory for the tests to use.
38
+ engine = create_engine(db_url)
39
+ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
40
+ session = TestingSessionLocal()
41
+
42
+ # Mock the SessionLocal factory in each module where it is imported and used.
43
+ mocker.patch("blossomtune_gradio.federation.SessionLocal", return_value=session)
44
+ mocker.patch("blossomtune_gradio.processing.SessionLocal", return_value=session)
45
+ mocker.patch("blossomtune_gradio.ui.callbacks.SessionLocal", return_value=session)
46
+
47
+ yield session
48
+
49
+ session.close()
50
+
51
+
52
+ @pytest.fixture
53
+ def mock_settings(mocker):
54
+ """Fixture to mock the settings module, available to all tests."""
55
+ mock_get = MagicMock(
56
+ side_effect=lambda key, **kwargs: f"mock_{key}".format(**kwargs)
57
+ )
58
+ mocker.patch("blossomtune_gradio.federation.settings.get_text", mock_get)
59
+ mocker.patch("blossomtune_gradio.ui.callbacks.settings.get_text", mock_get)
60
+ return mock_get
tests/test_federation.py CHANGED
@@ -1,18 +1,8 @@
1
  import pytest
2
- from unittest.mock import MagicMock
3
 
4
  from blossomtune_gradio import federation as fed
5
-
6
-
7
- @pytest.fixture
8
- def mock_settings(mocker):
9
- """Fixture to mock the settings module."""
10
- # The lambda returns a formatted string to simulate Jinja2's behavior
11
- mock_get = MagicMock(
12
- side_effect=lambda key, **kwargs: f"mock_{key}".format(**kwargs)
13
- )
14
- mocker.patch("blossomtune_gradio.federation.settings.get_text", mock_get)
15
- return mock_get
16
 
17
 
18
  @pytest.fixture
@@ -38,208 +28,202 @@ def test_generate_activation_code():
38
 
39
 
40
  class TestCheckParticipantStatus:
41
- """Test suite for the check_participant_status function."""
42
 
43
- def test_new_user_registration_success(
44
- self, in_memory_db, mock_settings, mock_mail
45
- ):
46
  """Verify successful registration for a new user."""
47
  mock_mail.return_value = (True, "")
48
- success, message, download = fed.check_participant_status(
49
  "new_user", "new@example.com", ""
50
  )
51
- assert success is True
52
  assert download is None
53
  assert message == "mock_registration_submitted_md"
54
 
55
  # Verify the user was added to the database
56
- cursor = in_memory_db.cursor()
57
- cursor.execute("SELECT hf_handle FROM requests WHERE hf_handle = 'new_user'")
58
- assert cursor.fetchone() is not None
59
 
60
- def test_new_user_invalid_email(self, in_memory_db, mock_settings):
61
  """Verify registration fails with an invalid email."""
62
- success, message, download = fed.check_participant_status(
63
  "user", "invalid-email", ""
64
  )
65
- assert success is False
66
  assert download is None
67
  assert message == "mock_invalid_email_md"
68
 
69
- def test_new_user_federation_full(self, in_memory_db, mock_settings, mocker):
70
  """Verify registration fails when the federation is full."""
71
- mocker.patch("blossomtune_gradio.federation.cfg.MAX_NUM_NODES", 0)
72
- success, message, download = fed.check_participant_status(
73
  "another_user", "another@example.com", ""
74
  )
75
- assert success is False
76
  assert download is None
77
  assert message == "mock_federation_full_md"
78
 
79
- def test_user_activation_success(self, in_memory_db, mock_settings):
80
  """Verify a user can successfully activate their account."""
81
  # Setup: Add a pending, non-activated user
82
- cursor = in_memory_db.cursor()
83
- cursor.execute(
84
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
85
- ("PID123", "pending", "now", "test_user", "test@example.com", "ABCDEF", 0),
 
 
86
  )
87
- in_memory_db.commit()
 
88
 
89
- success, message, download = fed.check_participant_status(
90
  "test_user", "test@example.com", "ABCDEF"
91
  )
92
- assert success is True
93
  assert download is None
94
  assert message == "mock_activation_successful_md"
95
  # Verify the user is now activated
96
- cursor.execute(
97
- "SELECT is_activated FROM requests WHERE hf_handle = 'test_user'"
98
  )
99
- assert cursor.fetchone()[0] == 1
100
 
101
- def test_user_activation_invalid_code(self, in_memory_db, mock_settings):
102
  """Verify activation fails with an invalid code."""
103
- # Setup: Add a pending, non-activated user
104
- cursor = in_memory_db.cursor()
105
- cursor.execute(
106
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
107
- ("PID123", "pending", "now", "test_user", "test@example.com", "ABCDEF", 0),
 
108
  )
109
- in_memory_db.commit()
 
110
 
111
- success, message, download = fed.check_participant_status(
112
  "test_user", "test@example.com", "WRONGCODE"
113
  )
114
- assert success is False
115
  assert download is None
116
  assert message == "mock_activation_invalid_md"
117
 
118
- def test_status_check_approved(self, in_memory_db, mock_settings):
119
  """Verify the status check for an approved user."""
120
- cursor = in_memory_db.cursor()
121
- cursor.execute(
122
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated, partition_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
123
- (
124
- "PID456",
125
- "approved",
126
- "now",
127
- "approved_user",
128
- "approved@example.com",
129
- "GHIJKL",
130
- 1,
131
- 5,
132
- ),
133
- )
134
- in_memory_db.commit()
135
- success, message, download = fed.check_participant_status(
136
  "approved_user", "approved@example.com", "GHIJKL"
137
  )
138
- assert success is True
139
  assert download is not None
140
  assert "mock_status_approved_md" in message
141
 
142
 
143
  class TestManageRequest:
144
- """Test suite for the manage_request function."""
145
 
146
- def test_approve_success(self, in_memory_db):
147
  """Verify successful approval of a participant."""
148
- # Setup: Add an activated, pending user
149
- cursor = in_memory_db.cursor()
150
- cursor.execute(
151
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
152
- (
153
- "PENDING1",
154
- "pending",
155
- "now",
156
- "pending_user",
157
- "pending@example.com",
158
- "CODE",
159
- 1,
160
- ),
161
- )
162
- in_memory_db.commit()
163
 
164
  success, message = fed.manage_request("PENDING1", "10", "approve")
165
  assert success is True
166
  assert "is allowed to join" in message
167
 
168
  # Verify status in DB
169
- cursor.execute(
170
- "SELECT status, partition_id FROM requests WHERE participant_id = 'PENDING1'"
171
  )
172
- status, partition_id = cursor.fetchone()
173
- assert status == "approved"
174
- assert partition_id == 10
175
 
176
- def test_approve_not_activated(self, in_memory_db, mock_settings):
177
  """Verify approval fails if the user is not activated."""
178
- cursor = in_memory_db.cursor()
179
- cursor.execute(
180
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
181
- (
182
- "PENDING2",
183
- "pending",
184
- "now",
185
- "pending_user2",
186
- "pending2@example.com",
187
- "CODE",
188
- 0,
189
- ),
190
- )
191
- in_memory_db.commit()
192
  success, message = fed.manage_request("PENDING2", "11", "approve")
193
  assert success is False
194
  assert message == "mock_participant_not_activated_warning_md"
195
 
196
- def test_deny_success(self, in_memory_db):
197
  """Verify successful denial of a participant."""
198
- cursor = in_memory_db.cursor()
199
- cursor.execute(
200
- "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
201
- (
202
- "PENDING3",
203
- "pending",
204
- "now",
205
- "pending_user3",
206
- "pending3@example.com",
207
- "CODE",
208
- 1,
209
- ),
210
- )
211
- in_memory_db.commit()
212
  success, message = fed.manage_request("PENDING3", "", "deny")
213
  assert success is True
214
  assert "is not allowed to join" in message
215
 
216
  # Verify status in DB
217
- cursor.execute("SELECT status FROM requests WHERE participant_id = 'PENDING3'")
218
- assert cursor.fetchone()[0] == "denied"
 
 
219
 
220
 
221
- def test_get_next_partition_id(in_memory_db):
222
  """Verify the logic for finding the next available partition ID."""
223
- cursor = in_memory_db.cursor()
224
  # No approved users yet
225
  assert fed.get_next_partion_id() == 0
226
 
227
- # Add some approved users with assigned partitions, including the required timestamp
228
- cursor.execute(
229
- "INSERT INTO requests (participant_id, status, timestamp, partition_id) VALUES (?, ?, ?, ?)",
230
- ("P1", "approved", "now", 0),
 
 
 
 
231
  )
232
- cursor.execute(
233
- "INSERT INTO requests (participant_id, status, timestamp, partition_id) VALUES (?, ?, ?, ?)",
234
- ("P2", "approved", "now", 1),
 
 
 
 
235
  )
236
- in_memory_db.commit()
237
  assert fed.get_next_partion_id() == 2
238
 
239
  # Add another user, skipping a partition ID
240
- cursor.execute(
241
- "INSERT INTO requests (participant_id, status, timestamp, partition_id) VALUES (?, ?, ?, ?)",
242
- ("P3", "approved", "now", 3),
 
 
 
 
243
  )
244
- in_memory_db.commit()
245
  assert fed.get_next_partion_id() == 2
 
1
  import pytest
2
+ from datetime import datetime
3
 
4
  from blossomtune_gradio import federation as fed
5
+ from blossomtune_gradio.database import Request
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  @pytest.fixture
 
28
 
29
 
30
  class TestCheckParticipantStatus:
31
+ """Test suite for the check_participant_status function using SQLAlchemy."""
32
 
33
+ def test_new_user_registration_success(self, db_session, mock_settings, mock_mail):
 
 
34
  """Verify successful registration for a new user."""
35
  mock_mail.return_value = (True, "")
36
+ approved, message, download = fed.check_participant_status(
37
  "new_user", "new@example.com", ""
38
  )
39
+ assert approved is False
40
  assert download is None
41
  assert message == "mock_registration_submitted_md"
42
 
43
  # Verify the user was added to the database
44
+ request = db_session.query(Request).filter_by(hf_handle="new_user").first()
45
+ assert request is not None
46
+ assert request.email == "new@example.com"
47
 
48
+ def test_new_user_invalid_email(self, db_session, mock_settings):
49
  """Verify registration fails with an invalid email."""
50
+ approved, message, download = fed.check_participant_status(
51
  "user", "invalid-email", ""
52
  )
53
+ assert approved is False
54
  assert download is None
55
  assert message == "mock_invalid_email_md"
56
 
57
+ def test_new_user_federation_full(self, db_session, mock_settings, mocker):
58
  """Verify registration fails when the federation is full."""
59
+ mocker.patch("blossomtune_gradio.config.MAX_NUM_NODES", 0)
60
+ approved, message, download = fed.check_participant_status(
61
  "another_user", "another@example.com", ""
62
  )
63
+ assert approved is False
64
  assert download is None
65
  assert message == "mock_federation_full_md"
66
 
67
+ def test_user_activation_success(self, db_session, mock_settings):
68
  """Verify a user can successfully activate their account."""
69
  # Setup: Add a pending, non-activated user
70
+ pending_user = Request(
71
+ participant_id="PID123",
72
+ hf_handle="test_user",
73
+ email="test@example.com",
74
+ activation_code="ABCDEF",
75
+ is_activated=0,
76
  )
77
+ db_session.add(pending_user)
78
+ db_session.commit()
79
 
80
+ approved, message, download = fed.check_participant_status(
81
  "test_user", "test@example.com", "ABCDEF"
82
  )
83
+ assert approved is False
84
  assert download is None
85
  assert message == "mock_activation_successful_md"
86
  # Verify the user is now activated
87
+ activated_user = (
88
+ db_session.query(Request).filter_by(hf_handle="test_user").first()
89
  )
90
+ assert activated_user.is_activated == 1
91
 
92
+ def test_user_activation_invalid_code(self, db_session, mock_settings):
93
  """Verify activation fails with an invalid code."""
94
+ pending_user = Request(
95
+ participant_id="PID123",
96
+ hf_handle="test_user",
97
+ email="test@example.com",
98
+ activation_code="ABCDEF",
99
+ is_activated=0,
100
  )
101
+ db_session.add(pending_user)
102
+ db_session.commit()
103
 
104
+ approved, message, download = fed.check_participant_status(
105
  "test_user", "test@example.com", "WRONGCODE"
106
  )
107
+ assert approved is False
108
  assert download is None
109
  assert message == "mock_activation_invalid_md"
110
 
111
+ def test_status_check_approved(self, db_session, mock_settings):
112
  """Verify the status check for an approved user."""
113
+ approved_user = Request(
114
+ participant_id="PID456",
115
+ status="approved",
116
+ hf_handle="approved_user",
117
+ email="approved@example.com",
118
+ activation_code="GHIJKL",
119
+ is_activated=1,
120
+ partition_id=5,
121
+ )
122
+ db_session.add(approved_user)
123
+ db_session.commit()
124
+
125
+ approved, message, download = fed.check_participant_status(
 
 
 
126
  "approved_user", "approved@example.com", "GHIJKL"
127
  )
128
+ assert approved is True
129
  assert download is not None
130
  assert "mock_status_approved_md" in message
131
 
132
 
133
  class TestManageRequest:
134
+ """Test suite for the manage_request function using SQLAlchemy."""
135
 
136
+ def test_approve_success(self, db_session):
137
  """Verify successful approval of a participant."""
138
+ pending_user = Request(
139
+ participant_id="PENDING1",
140
+ status="pending",
141
+ hf_handle="pending_user",
142
+ email="pending@example.com",
143
+ is_activated=1,
144
+ )
145
+ db_session.add(pending_user)
146
+ db_session.commit()
 
 
 
 
 
 
147
 
148
  success, message = fed.manage_request("PENDING1", "10", "approve")
149
  assert success is True
150
  assert "is allowed to join" in message
151
 
152
  # Verify status in DB
153
+ updated_user = (
154
+ db_session.query(Request).filter_by(participant_id="PENDING1").first()
155
  )
156
+ assert updated_user.status == "approved"
157
+ assert updated_user.partition_id == 10
 
158
 
159
+ def test_approve_not_activated(self, db_session, mock_settings):
160
  """Verify approval fails if the user is not activated."""
161
+ pending_user = Request(
162
+ participant_id="PENDING2",
163
+ status="pending",
164
+ hf_handle="pending_user2",
165
+ is_activated=0,
166
+ )
167
+ db_session.add(pending_user)
168
+ db_session.commit()
 
 
 
 
 
 
169
  success, message = fed.manage_request("PENDING2", "11", "approve")
170
  assert success is False
171
  assert message == "mock_participant_not_activated_warning_md"
172
 
173
+ def test_deny_success(self, db_session):
174
  """Verify successful denial of a participant."""
175
+ pending_user = Request(
176
+ participant_id="PENDING3",
177
+ status="pending",
178
+ hf_handle="pending_user3",
179
+ is_activated=1,
180
+ )
181
+ db_session.add(pending_user)
182
+ db_session.commit()
 
 
 
 
 
 
183
  success, message = fed.manage_request("PENDING3", "", "deny")
184
  assert success is True
185
  assert "is not allowed to join" in message
186
 
187
  # Verify status in DB
188
+ updated_user = (
189
+ db_session.query(Request).filter_by(participant_id="PENDING3").first()
190
+ )
191
+ assert updated_user.status == "denied"
192
 
193
 
194
+ def test_get_next_partition_id(db_session):
195
  """Verify the logic for finding the next available partition ID."""
 
196
  # No approved users yet
197
  assert fed.get_next_partion_id() == 0
198
 
199
+ # Add some approved users with assigned partitions
200
+ db_session.add(
201
+ Request(
202
+ participant_id="P1",
203
+ status="approved",
204
+ partition_id=0,
205
+ timestamp=datetime.utcnow(),
206
+ )
207
  )
208
+ db_session.add(
209
+ Request(
210
+ participant_id="P2",
211
+ status="approved",
212
+ partition_id=1,
213
+ timestamp=datetime.utcnow(),
214
+ )
215
  )
216
+ db_session.commit()
217
  assert fed.get_next_partion_id() == 2
218
 
219
  # Add another user, skipping a partition ID
220
+ db_session.add(
221
+ Request(
222
+ participant_id="P3",
223
+ status="approved",
224
+ partition_id=3,
225
+ timestamp=datetime.utcnow(),
226
+ )
227
  )
228
+ db_session.commit()
229
  assert fed.get_next_partion_id() == 2
tests/test_processing.py CHANGED
@@ -2,6 +2,7 @@ import pytest
2
  from unittest.mock import MagicMock, patch
3
 
4
  from blossomtune_gradio import processing
 
5
 
6
 
7
  @pytest.fixture(autouse=True)
@@ -12,8 +13,6 @@ def reset_process_store():
12
  """
13
  processing.process_store = {"superlink": None, "runner": None}
14
  yield
15
- # Teardown is not strictly necessary as it's reset at the start,
16
- # but it's good practice.
17
  processing.process_store = {"superlink": None, "runner": None}
18
 
19
 
@@ -29,18 +28,15 @@ def test_start_superlink_success(mock_which, mock_thread):
29
  assert success is True
30
  assert message == "Superlink process started."
31
  mock_thread.assert_called_once()
32
- # Check that the thread is targeting the run_process function
33
  call_args = mock_thread.call_args
34
  assert call_args.kwargs["target"] == processing.run_process
35
- # Check that the command is correct
36
  assert call_args.kwargs["args"][0] == ["/fake/path/flower-superlink", "--insecure"]
37
 
38
 
39
  def test_start_superlink_already_running(mocker):
40
  """Verify that start_superlink returns False if a process is already running."""
41
- # Mock a running process
42
  mock_process = MagicMock()
43
- mock_process.poll.return_value = None # None means it's running
44
  processing.process_store["superlink"] = mock_process
45
 
46
  success, message = processing.start_superlink()
@@ -49,33 +45,60 @@ def test_start_superlink_already_running(mocker):
49
 
50
 
51
  @patch("blossomtune_gradio.processing.os.path.exists", return_value=True)
52
- @patch("blossomtune_gradio.processing.sqlite3.connect")
53
  @patch("blossomtune_gradio.processing.threading.Thread")
54
  @patch("blossomtune_gradio.processing.shutil.which", return_value="/fake/path/flwr")
55
- def test_start_runner_success(mock_which, mock_thread, mock_sqlite, mock_exists):
56
- """Verify that start_runner successfully starts a thread when conditions are met."""
57
- # Mock a running superlink process
 
 
58
  mock_superlink = MagicMock()
59
  mock_superlink.poll.return_value = None
60
  processing.process_store["superlink"] = mock_superlink
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  success, message = processing.start_runner("app.main", "run1", "10")
63
 
 
64
  assert success is True
65
  assert message == "Federation Run is starting...."
66
  mock_thread.assert_called_once()
67
- mock_sqlite.assert_called_once() # Check if DB was updated
68
 
69
 
70
- def test_start_runner_superlink_not_running():
71
- """Verify start_runner fails if superlink is not running."""
72
  processing.process_store["superlink"] = None
73
  success, message = processing.start_runner("app.main", "run1", "10")
74
  assert success is False
75
- assert "Superlink is not running" in message
76
 
77
 
78
- def test_start_runner_missing_args():
79
  """Verify start_runner fails if arguments are missing."""
80
  mock_superlink = MagicMock()
81
  mock_superlink.poll.return_value = None
@@ -87,7 +110,7 @@ def test_start_runner_missing_args():
87
 
88
 
89
  @patch("blossomtune_gradio.processing.os.path.exists", return_value=False)
90
- def test_start_runner_app_path_not_found(mock_exists, in_memory_db):
91
  """Verify start_runner fails if the app path doesn't exist."""
92
  mock_superlink = MagicMock()
93
  mock_superlink.poll.return_value = None
@@ -101,7 +124,7 @@ def test_start_runner_app_path_not_found(mock_exists, in_memory_db):
101
  def test_stop_process_running():
102
  """Verify stop_process terminates a running process."""
103
  mock_process = MagicMock()
104
- mock_process.poll.return_value = None # Running
105
  processing.process_store["superlink"] = mock_process
106
 
107
  processing.stop_process("superlink")
@@ -118,7 +141,6 @@ def test_stop_process_not_running(mocker):
118
 
119
  processing.stop_process("superlink")
120
 
121
- # Check that the specific "no process was running" log was called
122
  log_mock.assert_any_call(
123
  "[Superlink] Stop command received, but no process was running."
124
  )
 
2
  from unittest.mock import MagicMock, patch
3
 
4
  from blossomtune_gradio import processing
5
+ from blossomtune_gradio.database import Config
6
 
7
 
8
  @pytest.fixture(autouse=True)
 
13
  """
14
  processing.process_store = {"superlink": None, "runner": None}
15
  yield
 
 
16
  processing.process_store = {"superlink": None, "runner": None}
17
 
18
 
 
28
  assert success is True
29
  assert message == "Superlink process started."
30
  mock_thread.assert_called_once()
 
31
  call_args = mock_thread.call_args
32
  assert call_args.kwargs["target"] == processing.run_process
 
33
  assert call_args.kwargs["args"][0] == ["/fake/path/flower-superlink", "--insecure"]
34
 
35
 
36
  def test_start_superlink_already_running(mocker):
37
  """Verify that start_superlink returns False if a process is already running."""
 
38
  mock_process = MagicMock()
39
+ mock_process.poll.return_value = None
40
  processing.process_store["superlink"] = mock_process
41
 
42
  success, message = processing.start_superlink()
 
45
 
46
 
47
  @patch("blossomtune_gradio.processing.os.path.exists", return_value=True)
 
48
  @patch("blossomtune_gradio.processing.threading.Thread")
49
  @patch("blossomtune_gradio.processing.shutil.which", return_value="/fake/path/flwr")
50
+ def test_start_runner_success_internal_superlink(
51
+ mock_which, mock_thread, mock_exists, db_session
52
+ ):
53
+ """Verify start_runner succeeds with an internal superlink and updates the DB."""
54
+ # Arrange: Mock a running internal superlink process
55
  mock_superlink = MagicMock()
56
  mock_superlink.poll.return_value = None
57
  processing.process_store["superlink"] = mock_superlink
58
 
59
+ # Act
60
+ success, message = processing.start_runner("app.main", "run1", "15")
61
+
62
+ # Assert
63
+ assert success is True
64
+ assert message == "Federation Run is starting...."
65
+ mock_thread.assert_called_once()
66
+
67
+ # Verify DB was updated using SQLAlchemy
68
+ config_entry = db_session.query(Config).filter_by(key="num_partitions").first()
69
+ assert config_entry is not None
70
+ assert config_entry.value == "15"
71
+
72
+
73
+ @patch("blossomtune_gradio.processing.os.path.exists", return_value=True)
74
+ @patch("blossomtune_gradio.processing.threading.Thread")
75
+ @patch("blossomtune_gradio.processing.shutil.which", return_value="/fake/path/flwr")
76
+ def test_start_runner_success_external_superlink(
77
+ mock_which, mock_thread, mock_exists, db_session, mocker
78
+ ):
79
+ """Verify start_runner succeeds with an external superlink."""
80
+ # Arrange
81
+ mocker.patch("blossomtune_gradio.config.SUPERLINK_MODE", "external")
82
+ mocker.patch("blossomtune_gradio.util.is_port_open", return_value=True)
83
+
84
+ # Act
85
  success, message = processing.start_runner("app.main", "run1", "10")
86
 
87
+ # Assert
88
  assert success is True
89
  assert message == "Federation Run is starting...."
90
  mock_thread.assert_called_once()
 
91
 
92
 
93
+ def test_start_runner_internal_superlink_not_running(db_session):
94
+ """Verify start_runner fails if internal superlink is not running."""
95
  processing.process_store["superlink"] = None
96
  success, message = processing.start_runner("app.main", "run1", "10")
97
  assert success is False
98
+ assert "Internal Superlink is not running" in message
99
 
100
 
101
+ def test_start_runner_missing_args(db_session):
102
  """Verify start_runner fails if arguments are missing."""
103
  mock_superlink = MagicMock()
104
  mock_superlink.poll.return_value = None
 
110
 
111
 
112
  @patch("blossomtune_gradio.processing.os.path.exists", return_value=False)
113
+ def test_start_runner_app_path_not_found(mock_exists, db_session):
114
  """Verify start_runner fails if the app path doesn't exist."""
115
  mock_superlink = MagicMock()
116
  mock_superlink.poll.return_value = None
 
124
  def test_stop_process_running():
125
  """Verify stop_process terminates a running process."""
126
  mock_process = MagicMock()
127
+ mock_process.poll.return_value = None
128
  processing.process_store["superlink"] = mock_process
129
 
130
  processing.stop_process("superlink")
 
141
 
142
  processing.stop_process("superlink")
143
 
 
144
  log_mock.assert_any_call(
145
  "[Superlink] Stop command received, but no process was running."
146
  )
uv.lock CHANGED
The diff for this file is too large to render. See raw diff