mrs83 commited on
Commit
bc67f56
·
unverified ·
1 Parent(s): 905e1ea

Add Elliptic Curve (EC) Authentication (#21)

Browse files

* Adds a single public_key_pem column to the database.
* Add --auth-list-public-keys to docker entrypoint's superlink cmd
* Add --auth-list-public-keys to superlink subprocess cmd
* Add blossomfile generation
* Handle custom csv file format support
* Convert PEM-formatted key to ssh format
* Fix tests and dependencies
* Fix test_rebuild_overwrites_existing_file
* Add requirements.txt for custom HF space install

alembic/versions/e7169fe29ea1_add_public_key_pem_to_request_table.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Add public_key_pem to Request table.
2
+
3
+ Revision ID: e7169fe29ea1
4
+ Revises: 0b01b44ab005
5
+ Create Date: 2025-10-09 13:23:30.141015
6
+
7
+ """
8
+
9
+ from typing import Sequence, Union
10
+
11
+ from alembic import op
12
+ import sqlalchemy as sa
13
+
14
+
15
+ # revision identifiers, used by Alembic.
16
+ revision: str = "e7169fe29ea1"
17
+ down_revision: Union[str, Sequence[str], None] = "0b01b44ab005"
18
+ branch_labels: Union[str, Sequence[str], None] = None
19
+ depends_on: Union[str, Sequence[str], None] = None
20
+
21
+
22
+ def upgrade() -> None:
23
+ """Upgrade schema."""
24
+ # ### commands auto generated by Alembic - please adjust! ###
25
+ op.add_column("requests", sa.Column("public_key_pem", sa.String(), nullable=True))
26
+ # ### end Alembic commands ###
27
+
28
+
29
+ def downgrade() -> None:
30
+ """Downgrade schema."""
31
+ # ### commands auto generated by Alembic - please adjust! ###
32
+ op.drop_column("requests", "public_key_pem")
33
+ # ### end Alembic commands ###
blossomtune_gradio/auth_keys.py CHANGED
@@ -1,39 +1,69 @@
1
  import os
2
- import csv
3
  import logging
4
  from typing import List, Tuple
5
  from cryptography.hazmat.primitives.asymmetric import ec
6
  from cryptography.hazmat.primitives import serialization
7
 
 
8
  log = logging.getLogger(__name__)
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def rebuild_authorized_keys_csv(
12
  key_dir: str, authorized_participants: List[Tuple[str, str]]
13
  ):
14
  """
15
- Overwrites the public key CSV with a fresh list from a trusted source.
16
-
17
- This function should be called before starting the Flower SuperLink to ensure
18
- the list of authorized nodes is always in sync with the database.
19
-
20
- Args:
21
- key_dir (str): The directory where the CSV will be stored.
22
- authorized_participants (List[Tuple[str, str]]): A list of tuples,
23
- where each tuple contains (participant_id, public_key_pem).
24
  """
25
  csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
26
- log.info(f"Rebuilding authorized keys CSV at: {csv_path}")
27
 
28
- with open(csv_path, "w", newline="") as f:
29
- writer = csv.writer(f)
30
- writer.writerow(["participant_id", "public_key_pem"])
31
- for participant_id, public_key_pem in authorized_participants:
32
- writer.writerow([participant_id, public_key_pem])
 
33
 
34
- log.info(
35
- f"Successfully rebuilt {csv_path} with {len(authorized_participants)} keys."
36
- )
 
 
 
 
 
37
 
38
 
39
  class AuthKeyGenerator:
@@ -43,12 +73,6 @@ class AuthKeyGenerator:
43
  """
44
 
45
  def __init__(self, key_dir: str = "keys"):
46
- """
47
- Initializes the generator and ensures the key directory exists.
48
-
49
- Args:
50
- key_dir (str): The directory where key files will be stored.
51
- """
52
  self.key_dir = key_dir
53
  os.makedirs(self.key_dir, exist_ok=True)
54
  log.info(f"Authentication key directory set to: {self.key_dir}")
@@ -70,24 +94,17 @@ class AuthKeyGenerator:
70
  encryption_algorithm=serialization.NoEncryption(),
71
  )
72
  )
73
- # Set file permissions to read/write for owner only (600)
74
  os.chmod(priv_key_path, 0o600)
75
  log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
76
  return priv_key_path
77
 
78
  def _save_public_key_file(
79
- self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
80
  ) -> str:
81
- """Saves the public key to a .pub file."""
82
- public_key = private_key.public_key()
83
  pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
84
- with open(pub_key_path, "wb") as f:
85
- f.write(
86
- public_key.public_bytes(
87
- encoding=serialization.Encoding.PEM,
88
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
89
- )
90
- )
91
  log.info(f"Public key for {participant_id} saved to {pub_key_path}")
92
  return pub_key_path
93
 
@@ -95,26 +112,25 @@ class AuthKeyGenerator:
95
  """
96
  Generates and saves a new EC key pair for a participant.
97
 
98
- This is the main public method to call when a new participant is approved.
99
-
100
- Args:
101
- participant_id (str): A unique identifier for the participant.
102
-
103
  Returns:
104
  A tuple containing:
105
  - The file path to the generated private key.
106
  - The file path to the generated public key.
107
- - The public key as a PEM-encoded string (for database storage).
108
  """
109
  private_key = self._generate_key_pair()
110
  public_key = private_key.public_key()
111
 
112
- priv_key_path = self._save_private_key(private_key, participant_id)
113
- pub_key_path = self._save_public_key_file(private_key, participant_id)
114
-
115
- public_key_pem = public_key.public_bytes(
116
- encoding=serialization.Encoding.PEM,
117
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
118
  ).decode("utf-8")
119
 
120
- return (priv_key_path, pub_key_path, public_key_pem)
 
 
 
 
 
 
 
1
  import os
 
2
  import logging
3
  from typing import List, Tuple
4
  from cryptography.hazmat.primitives.asymmetric import ec
5
  from cryptography.hazmat.primitives import serialization
6
 
7
+ # Configure logging for the module
8
  log = logging.getLogger(__name__)
9
 
10
 
11
+ def _sanitize_key(participant_id: str, key_str: str) -> str | None:
12
+ """
13
+ Inspects a key string and converts it to the required OpenSSH format if necessary.
14
+ This provides resilience against old, PEM-formatted keys in the database.
15
+ """
16
+ if not key_str:
17
+ return None
18
+ # If the key is already in the correct OpenSSH format, return it as is.
19
+ if key_str.startswith("ecdsa-sha2-nistp384"):
20
+ return key_str
21
+ # If the key is in the old PEM format, attempt to convert it.
22
+ if "-----BEGIN PUBLIC KEY-----" in key_str:
23
+ log.warning(
24
+ f"Found PEM-formatted key for participant {participant_id}. Converting to OpenSSH."
25
+ )
26
+ try:
27
+ public_key = serialization.load_pem_public_key(key_str.encode("utf-8"))
28
+ key_body = public_key.public_bytes(
29
+ encoding=serialization.Encoding.OpenSSH,
30
+ format=serialization.PublicFormat.OpenSSH,
31
+ ).decode("utf-8")
32
+ # Re-add the participant_id as a comment to conform to the 3-part format.
33
+ return f"{key_body} {participant_id}"
34
+ except Exception as e:
35
+ log.error(f"Could not convert PEM key for {participant_id}: {e}")
36
+ return None
37
+ # If the key format is unknown, log an error and skip it.
38
+ log.error(f"Unknown public key format for participant {participant_id}. Skipping.")
39
+ return None
40
+
41
+
42
  def rebuild_authorized_keys_csv(
43
  key_dir: str, authorized_participants: List[Tuple[str, str]]
44
  ):
45
  """
46
+ Overwrites the public key file with a fresh list from a trusted source,
47
+ using the specific single-line, comma-separated format expected by Flower.
 
 
 
 
 
 
 
48
  """
49
  csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
50
+ log.info(f"Rebuilding authorized keys file at: {csv_path}")
51
 
52
+ # Sanitize each key before adding it to the list.
53
+ public_keys = [
54
+ sanitized_key
55
+ for p_id, key_string in authorized_participants
56
+ if (sanitized_key := _sanitize_key(p_id, key_string)) is not None
57
+ ]
58
 
59
+ # Join all valid public keys into a single comma-separated string.
60
+ content = ",".join(public_keys)
61
+
62
+ # Write the single line to the file, followed by a newline.
63
+ with open(csv_path, "w") as f:
64
+ f.write(content + "\n")
65
+
66
+ log.info(f"Successfully rebuilt {csv_path} with {len(public_keys)} keys.")
67
 
68
 
69
  class AuthKeyGenerator:
 
73
  """
74
 
75
  def __init__(self, key_dir: str = "keys"):
 
 
 
 
 
 
76
  self.key_dir = key_dir
77
  os.makedirs(self.key_dir, exist_ok=True)
78
  log.info(f"Authentication key directory set to: {self.key_dir}")
 
94
  encryption_algorithm=serialization.NoEncryption(),
95
  )
96
  )
 
97
  os.chmod(priv_key_path, 0o600)
98
  log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
99
  return priv_key_path
100
 
101
  def _save_public_key_file(
102
+ self, public_key_ssh_string: str, participant_id: str
103
  ) -> str:
104
+ """Saves the full OpenSSH public key string to a .pub file."""
 
105
  pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
106
+ with open(pub_key_path, "w") as f:
107
+ f.write(public_key_ssh_string)
 
 
 
 
 
108
  log.info(f"Public key for {participant_id} saved to {pub_key_path}")
109
  return pub_key_path
110
 
 
112
  """
113
  Generates and saves a new EC key pair for a participant.
114
 
 
 
 
 
 
115
  Returns:
116
  A tuple containing:
117
  - The file path to the generated private key.
118
  - The file path to the generated public key.
119
+ - The public key as a single-line OpenSSH string with a comment.
120
  """
121
  private_key = self._generate_key_pair()
122
  public_key = private_key.public_key()
123
 
124
+ # Generate the base OpenSSH key string (type and key data)
125
+ key_body = public_key.public_bytes(
126
+ encoding=serialization.Encoding.OpenSSH,
127
+ format=serialization.PublicFormat.OpenSSH,
 
 
128
  ).decode("utf-8")
129
 
130
+ # Append the participant_id as a comment, creating the full 3-part key.
131
+ public_key_ssh_string = f"{key_body} {participant_id}"
132
+
133
+ priv_key_path = self._save_private_key(private_key, participant_id)
134
+ pub_key_path = self._save_public_key_file(public_key_ssh_string, participant_id)
135
+
136
+ return (priv_key_path, pub_key_path, public_key_ssh_string)
blossomtune_gradio/blossomfile.py CHANGED
@@ -39,7 +39,7 @@ def create_blossomfile(
39
  The full path to the generated .blossomfile.
40
  """
41
  os.makedirs(output_dir, exist_ok=True)
42
- blossomfile_path = os.path.join(output_dir, f"{participant_id}.blossomfile")
43
  log.info(f"Creating Blossomfile for {participant_id} at {blossomfile_path}")
44
 
45
  # 1. Create the blossom.json configuration data
 
39
  The full path to the generated .blossomfile.
40
  """
41
  os.makedirs(output_dir, exist_ok=True)
42
+ blossomfile_path = os.path.join(output_dir, f"{participant_id}-blossomfile.zip")
43
  log.info(f"Creating Blossomfile for {participant_id} at {blossomfile_path}")
44
 
45
  # 1. Create the blossom.json configuration data
blossomtune_gradio/config.py CHANGED
@@ -22,7 +22,7 @@ SMTP_REQUIRE_TLS = util.strtobool(os.getenv("SMTP_REQUIRE_TLS", "false"))
22
  SMTP_USER = os.getenv("SMTP_USER", "")
23
  SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
24
  EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
25
- SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1:9092")
26
  SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
27
  SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
28
  SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
@@ -50,6 +50,20 @@ BLOSSOMTUNE_TLS_CA_CERTFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "ca.crt")
50
  BLOSSOMTUNE_TLS_CERTFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "server.pem")
51
  BLOSSOMTUNE_TLS_KEYFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "server.key")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Flower Apps
54
  FLOWER_APPS = os.getenv("FLOWER_APPS", ["flower_apps.quickstart_huggingface"])
55
  FLOWER_APPS = (
 
22
  SMTP_USER = os.getenv("SMTP_USER", "")
23
  SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
24
  EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
25
+ SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1")
26
  SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
27
  SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
28
  SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
 
50
  BLOSSOMTUNE_TLS_CERTFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "server.pem")
51
  BLOSSOMTUNE_TLS_KEYFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "server.key")
52
 
53
+ # EC Auth - Keys
54
+ AUTH_KEYS_DIR = os.getenv(
55
+ "AUTH_KEYS_DIR",
56
+ "/data/keys/"
57
+ if os.path.isdir("/data/keys")
58
+ else os.path.join(PROJECT_PATH, "./data/keys/"),
59
+ )
60
+
61
+ # EC Auth - CSV File
62
+ AUTH_KEYS_CSV_PATH = os.getenv(
63
+ "AUTH_KEYS_CSV_PATH",
64
+ os.path.join(AUTH_KEYS_DIR, "authorized_supernodes.csv"),
65
+ )
66
+
67
  # Flower Apps
68
  FLOWER_APPS = os.getenv("FLOWER_APPS", ["flower_apps.quickstart_huggingface"])
69
  FLOWER_APPS = (
blossomtune_gradio/database.py CHANGED
@@ -29,6 +29,7 @@ class Request(Base):
29
  hf_handle = Column(String, nullable=True)
30
  activation_code = Column(String, nullable=True)
31
  is_activated = Column(Integer, nullable=False, default=0)
 
32
 
33
  def __repr__(self):
34
  return (
 
29
  hf_handle = Column(String, nullable=True)
30
  activation_code = Column(String, nullable=True)
31
  is_activated = Column(Integer, nullable=False, default=0)
32
+ public_key_pem = Column(String(), nullable=True)
33
 
34
  def __repr__(self):
35
  return (
blossomtune_gradio/federation.py CHANGED
@@ -1,11 +1,15 @@
 
1
  import string
2
  import secrets
 
3
 
4
  from blossomtune_gradio import config as cfg
5
  from blossomtune_gradio import mail
6
  from blossomtune_gradio import util
7
  from blossomtune_gradio.settings import settings
8
  from blossomtune_gradio.database import SessionLocal, Request, Config
 
 
9
 
10
 
11
  def generate_participant_id(length=6):
@@ -24,7 +28,6 @@ def check_participant_status(pid_to_check: str, email: str, activation_code: str
24
  """
25
  Handles a participant's request to join, activate, or check status using SQLAlchemy.
26
  Returns a tuple: (is_approved: bool, message: str, data: any | None)
27
- The 'is_approved' boolean is True ONLY when the participant's final status is 'approved'.
28
  """
29
  with SessionLocal() as db:
30
  query = db.query(Request).filter(
@@ -40,51 +43,46 @@ def check_participant_status(pid_to_check: str, email: str, activation_code: str
40
  )
41
  num_partitions = num_partitions_config.value if num_partitions_config else "10"
42
 
43
- # Case 1: New user registration
44
- if request is None:
45
- if activation_code:
46
- return (False, settings.get_text("activation_invalid_md"), None)
47
- if not util.validate_email(email):
48
- return (False, settings.get_text("invalid_email_md"), None)
49
-
50
- approved_count = (
51
- db.query(Request).filter(Request.status == "approved").count()
52
- )
53
- if approved_count >= cfg.MAX_NUM_NODES:
54
- return (False, settings.get_text("federation_full_md"), None)
55
-
56
- participant_id = generate_participant_id()
57
- new_activation_code = generate_activation_code()
58
- mail_sent, message = mail.send_activation_email(email, new_activation_code)
59
-
60
- if mail_sent:
61
- new_request = Request(
62
- participant_id=participant_id,
63
- hf_handle=pid_to_check,
64
- email=email,
65
- activation_code=new_activation_code,
66
  )
67
- db.add(new_request)
68
- db.commit()
69
- # A successful registration step, but not yet approved for federation.
70
- return (False, settings.get_text("registration_submitted_md"), None)
71
- else:
72
- return (False, message, None)
73
-
74
- # Case 2: User is activating their account
75
- if not request.is_activated:
76
- if activation_code == request.activation_code:
77
- request.is_activated = 1
78
- db.commit()
79
- # A successful activation step, but not yet approved.
80
- return (False, settings.get_text("activation_successful_md"), None)
81
- else:
82
- return (False, settings.get_text("activation_invalid_md"), None)
83
-
84
- # At this point, user is activated.
85
- # They must provide the activation code to check their final status.
86
- if not activation_code:
87
- return (False, settings.get_text("missing_activation_code_md"), None)
 
 
88
 
89
  # Case 3: Activated user is checking their final status
90
  if request.status == "approved":
@@ -93,17 +91,37 @@ def check_participant_status(pid_to_check: str, email: str, activation_code: str
93
  if not cfg.SPACE_ID
94
  else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
95
  )
96
- superlink_hostname = cfg.SUPERLINK_HOST or hostname
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  connection_string = settings.get_text(
99
  "status_approved_md",
100
  participant_id=request.participant_id,
101
  partition_id=request.partition_id,
102
- superlink_hostname=superlink_hostname,
 
103
  num_partitions=num_partitions,
104
  )
105
- # The user is fully approved. Return success and the cert path.
106
- return (True, connection_string, cfg.BLOSSOMTUNE_TLS_CA_CERTFILE)
107
  elif request.status == "pending":
108
  return (False, settings.get_text("status_pending_md"), None)
109
  else: # Denied
@@ -131,7 +149,6 @@ def manage_request(participant_id: str, partition_id: str, action: str):
131
  if action == "approve":
132
  if not partition_id or not partition_id.isdigit():
133
  return False, "Please provide a valid integer for the Partition ID."
134
-
135
  p_id_int = int(partition_id)
136
  if not request.is_activated:
137
  return (
@@ -144,7 +161,6 @@ def manage_request(participant_id: str, partition_id: str, action: str):
144
  .filter(Request.partition_id == p_id_int, Request.status == "approved")
145
  .first()
146
  )
147
-
148
  if existing_participant:
149
  return (
150
  False,
@@ -155,18 +171,48 @@ def manage_request(participant_id: str, partition_id: str, action: str):
155
 
156
  request.status = "approved"
157
  request.partition_id = p_id_int
 
 
 
 
 
 
 
158
  db.commit()
 
 
 
 
 
 
 
 
 
 
 
159
  return (
160
  True,
161
- f"Participant {participant_id} is allowed to join the federation.",
162
  )
163
  else: # Deny
164
  request.status = "denied"
165
  request.partition_id = None
 
166
  db.commit()
 
 
 
 
 
 
 
 
 
 
 
167
  return (
168
  True,
169
- f"Participant {participant_id} is not allowed to join the federation.",
170
  )
171
 
172
 
 
1
+ import os
2
  import string
3
  import secrets
4
+ import tempfile
5
 
6
  from blossomtune_gradio import config as cfg
7
  from blossomtune_gradio import mail
8
  from blossomtune_gradio import util
9
  from blossomtune_gradio.settings import settings
10
  from blossomtune_gradio.database import SessionLocal, Request, Config
11
+ from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
12
+ from blossomtune_gradio.blossomfile import create_blossomfile
13
 
14
 
15
  def generate_participant_id(length=6):
 
28
  """
29
  Handles a participant's request to join, activate, or check status using SQLAlchemy.
30
  Returns a tuple: (is_approved: bool, message: str, data: any | None)
 
31
  """
32
  with SessionLocal() as db:
33
  query = db.query(Request).filter(
 
43
  )
44
  num_partitions = num_partitions_config.value if num_partitions_config else "10"
45
 
46
+ # Case 1 & 2 are for users not yet approved
47
+ if not request or not request.is_activated or not activation_code:
48
+ if request is None:
49
+ if activation_code:
50
+ return (False, settings.get_text("activation_invalid_md"), None)
51
+ if not util.validate_email(email):
52
+ return (False, settings.get_text("invalid_email_md"), None)
53
+ approved_count = (
54
+ db.query(Request).filter(Request.status == "approved").count()
55
+ )
56
+ if approved_count >= cfg.MAX_NUM_NODES:
57
+ return (False, settings.get_text("federation_full_md"), None)
58
+ participant_id = generate_participant_id()
59
+ new_activation_code = generate_activation_code()
60
+ mail_sent, message = mail.send_activation_email(
61
+ email, new_activation_code
 
 
 
 
 
 
 
62
  )
63
+ if mail_sent:
64
+ new_request = Request(
65
+ participant_id=participant_id,
66
+ hf_handle=pid_to_check,
67
+ email=email,
68
+ activation_code=new_activation_code,
69
+ )
70
+ db.add(new_request)
71
+ db.commit()
72
+ return (False, settings.get_text("registration_submitted_md"), None)
73
+ else:
74
+ return (False, message, None)
75
+
76
+ if not request.is_activated:
77
+ if activation_code == request.activation_code:
78
+ request.is_activated = 1
79
+ db.commit()
80
+ return (False, settings.get_text("activation_successful_md"), None)
81
+ else:
82
+ return (False, settings.get_text("activation_invalid_md"), None)
83
+
84
+ if not activation_code:
85
+ return (False, settings.get_text("missing_activation_code_md"), None)
86
 
87
  # Case 3: Activated user is checking their final status
88
  if request.status == "approved":
 
91
  if not cfg.SPACE_ID
92
  else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
93
  )
94
+ superlink_address = f"{cfg.SUPERLINK_HOST or hostname}:{cfg.SUPERLINK_PORT}"
95
+
96
+ # Blossomfile Generation
97
+ blossomfile_tempdir = tempfile.mkdtemp() # TODO: remove tempdirs
98
+ try:
99
+ blossomfile_path = create_blossomfile(
100
+ participant_id=request.participant_id,
101
+ output_dir=blossomfile_tempdir,
102
+ ca_cert_path=cfg.BLOSSOMTUNE_TLS_CA_CERTFILE,
103
+ auth_key_path=os.path.join(
104
+ cfg.AUTH_KEYS_DIR, f"{request.participant_id}.key"
105
+ ),
106
+ auth_pub_path=os.path.join(
107
+ cfg.AUTH_KEYS_DIR, f"{request.participant_id}.pub"
108
+ ),
109
+ superlink_address=superlink_address,
110
+ partition_id=request.partition_id,
111
+ num_partitions=int(num_partitions),
112
+ )
113
+ except FileNotFoundError:
114
+ return (False, "An error occurred.", None)
115
 
116
  connection_string = settings.get_text(
117
  "status_approved_md",
118
  participant_id=request.participant_id,
119
  partition_id=request.partition_id,
120
+ superlink_hostname=superlink_address.split(":")[0],
121
+ superlink_port=superlink_address.split(":")[1],
122
  num_partitions=num_partitions,
123
  )
124
+ return (True, connection_string, blossomfile_path)
 
125
  elif request.status == "pending":
126
  return (False, settings.get_text("status_pending_md"), None)
127
  else: # Denied
 
149
  if action == "approve":
150
  if not partition_id or not partition_id.isdigit():
151
  return False, "Please provide a valid integer for the Partition ID."
 
152
  p_id_int = int(partition_id)
153
  if not request.is_activated:
154
  return (
 
161
  .filter(Request.partition_id == p_id_int, Request.status == "approved")
162
  .first()
163
  )
 
164
  if existing_participant:
165
  return (
166
  False,
 
171
 
172
  request.status = "approved"
173
  request.partition_id = p_id_int
174
+
175
+ # Generate and Store Auth Keys
176
+ key_generator = AuthKeyGenerator(key_dir=cfg.AUTH_KEYS_DIR)
177
+ _, _, public_key_pem = key_generator.generate_participant_keys(
178
+ participant_id
179
+ )
180
+ request.public_key_pem = public_key_pem
181
  db.commit()
182
+
183
+ # Rebuild Authorized Keys CSV
184
+ approved_participants = (
185
+ db.query(Request.participant_id, Request.public_key_pem)
186
+ .filter(
187
+ Request.status == "approved", Request.public_key_pem.isnot(None)
188
+ )
189
+ .all()
190
+ )
191
+ rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)
192
+
193
  return (
194
  True,
195
+ f"Participant {participant_id} approved. Keys generated and registry updated.",
196
  )
197
  else: # Deny
198
  request.status = "denied"
199
  request.partition_id = None
200
+ request.public_key_pem = None
201
  db.commit()
202
+
203
+ # --- Rebuild CSV after denial to revoke access ---
204
+ approved_participants = (
205
+ db.query(Request.participant_id, Request.public_key_pem)
206
+ .filter(
207
+ Request.status == "approved", Request.public_key_pem.isnot(None)
208
+ )
209
+ .all()
210
+ )
211
+ rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)
212
+
213
  return (
214
  True,
215
+ f"Participant {participant_id} denied. Their access has been revoked.",
216
  )
217
 
218
 
blossomtune_gradio/processing.py CHANGED
@@ -54,7 +54,9 @@ def start_superlink():
54
  cfg.BLOSSOMTUNE_TLS_CERTFILE,
55
  "--ssl-keyfile",
56
  cfg.BLOSSOMTUNE_TLS_KEYFILE,
57
- ] # Placeholder
 
 
58
  threading.Thread(
59
  target=run_process, args=(command, "superlink"), daemon=True
60
  ).start()
 
54
  cfg.BLOSSOMTUNE_TLS_CERTFILE,
55
  "--ssl-keyfile",
56
  cfg.BLOSSOMTUNE_TLS_KEYFILE,
57
+ "--auth-list-public-keys",
58
+ cfg.AUTH_KEYS_CSV_PATH,
59
+ ]
60
  threading.Thread(
61
  target=run_process, args=(command, "superlink"), daemon=True
62
  ).start()
blossomtune_gradio/settings/blossomtune.schema.json CHANGED
@@ -100,7 +100,7 @@
100
  "status_pending_md",
101
  "status_denied_md",
102
  "participant_not_activated_warning_md",
103
- "partition_in_use_warning",
104
  "auth_status_logged_in_owner_md",
105
  "auth_status_logged_in_user_md",
106
  "auth_status_not_logged_in_md",
 
100
  "status_pending_md",
101
  "status_denied_md",
102
  "participant_not_activated_warning_md",
103
+ "partition_in_use_warning_md",
104
  "auth_status_logged_in_owner_md",
105
  "auth_status_logged_in_user_md",
106
  "auth_status_not_logged_in_md",
blossomtune_gradio/settings/blossomtune.yaml CHANGED
@@ -66,7 +66,7 @@ ui:
66
  participant_not_activated_warning_md: |
67
  This participant has not activated their email yet. Approval is not allowed.
68
 
69
- partition_in_use_warning: |
70
  Partition ID {{ partition_id }} is already assigned. Please choose a different one.
71
 
72
  # --- Authentication Status ---
 
66
  participant_not_activated_warning_md: |
67
  This participant has not activated their email yet. Approval is not allowed.
68
 
69
+ partition_in_use_warning_md: |
70
  Partition ID {{ partition_id }} is already assigned. Please choose a different one.
71
 
72
  # --- Authentication Status ---
docker_entrypoint.sh CHANGED
@@ -1,5 +1,4 @@
1
  #!/bin/bash
2
-
3
  set -e
4
 
5
  echo "Env vars:"
@@ -13,7 +12,8 @@ elif [ "${1}" = "superlink" ]; then
13
  exec flower-superlink \
14
  --ssl-ca-certfile /data/certs/ca.crt \
15
  --ssl-certfile /data/certs/server.pem \
16
- --ssl-keyfile /data/certs/server.key
 
17
  else
18
  exec "$@"
19
  fi
 
1
  #!/bin/bash
 
2
  set -e
3
 
4
  echo "Env vars:"
 
12
  exec flower-superlink \
13
  --ssl-ca-certfile /data/certs/ca.crt \
14
  --ssl-certfile /data/certs/server.pem \
15
+ --ssl-keyfile /data/certs/server.key \
16
+ --auth-list-public-keys /data/keys/authorized_supernodes.csv
17
  else
18
  exec "$@"
19
  fi
pyproject.toml CHANGED
@@ -18,6 +18,9 @@ dependencies = [
18
  "mlx[cpu]>=0.29.2",
19
  "alembic>=1.16.5",
20
  "sqlalchemy>=2.0.43",
 
 
 
21
  ]
22
 
23
  [tool.uv.sources]
@@ -71,8 +74,6 @@ convention = "google" # Accepts: "google", "numpy", or "pep257".
71
 
72
  [dependency-groups]
73
  dev = [
74
- "cryptography>=44.0.3",
75
- "dnspython>=2.8.0",
76
  "pytest>=8.4.1",
77
  "pytest-mock>=3.15.1",
78
  ]
 
18
  "mlx[cpu]>=0.29.2",
19
  "alembic>=1.16.5",
20
  "sqlalchemy>=2.0.43",
21
+ "cryptography>=44.0.3",
22
+ "dnspython>=2.8.0",
23
+ "mlx-lm>=0.28.2",
24
  ]
25
 
26
  [tool.uv.sources]
 
74
 
75
  [dependency-groups]
76
  dev = [
 
 
77
  "pytest>=8.4.1",
78
  "pytest-mock>=3.15.1",
79
  ]
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apscheduler>=3.11.0
2
+ flwr[simulation]>=1.21.0
3
+ flwr-datasets>=0.5.0
4
+ gradio[oauth]>=5.44.1
5
+ torch>=2.8.0
6
+ transformers>=4.56.1
7
+ scikit-learn>=1.7.1
8
+ evaluate>=0.4.5
9
+ markupsafe==2.1.3
10
+ jinja2>=3.1.6
11
+ # mlx[cpu]>=0.29.2
12
+ alembic>=1.16.5
13
+ sqlalchemy>=2.0.43
14
+ cryptography>=44.0.3
15
+ dnspython>=2.8.0
16
+ # mlx-lm>=0.28.2
tests/test_auth_keys.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import csv
3
  import stat
4
  import pytest
5
  from cryptography.hazmat.primitives import serialization
@@ -25,109 +24,117 @@ class TestAuthKeyGenerator:
25
  AuthKeyGenerator(key_dir=str(key_dir))
26
  assert os.path.exists(key_dir)
27
 
28
- def test_generate_participant_keys_creates_files_and_returns_data(
29
  self, key_generator
30
  ):
31
  """
32
- Verify that the main method generates all expected files and returns
33
- the correct data tuple.
34
  """
35
  participant_id = "participant_01"
36
- priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
37
  participant_id
38
  )
39
 
40
- # 1. Check if files exist
41
  assert os.path.exists(priv_path)
42
  assert os.path.exists(pub_path)
43
- assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
44
- assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")
45
-
46
- # 2. Check private key file permissions (security check)
47
- # In non-Windows environments, check for 600 permissions.
48
  if os.name != "nt":
49
  file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
50
  assert file_mode == 0o600
51
 
52
- # 3. Verify key formats and consistency
53
- with open(priv_path, "rb") as f:
54
- private_key = serialization.load_pem_private_key(f.read(), password=None)
55
- with open(pub_path, "rb") as f:
56
- public_key = serialization.load_pem_public_key(f.read())
57
- _ = f.read() # Read again to get bytes
58
-
59
- assert isinstance(private_key, ec.EllipticCurvePrivateKey)
60
- assert isinstance(public_key, ec.EllipticCurvePublicKey)
61
 
62
- # Check that the returned PEM string matches the public key
63
- generated_public_key = private_key.public_key()
64
- pem_from_private = generated_public_key.public_bytes(
65
- encoding=serialization.Encoding.PEM,
66
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
67
- ).decode("utf-8")
68
- assert pub_pem == pem_from_private
69
 
70
 
71
- class TestRebuildCSV:
72
  """Test suite for the rebuild_authorized_keys_csv function."""
73
 
74
- def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
75
- """Verify a CSV with only a header is created for an empty participant list."""
76
- key_dir = tmp_path / "csv_test"
77
  os.makedirs(key_dir)
78
- csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
79
 
80
- rebuild_authorized_keys_csv(key_dir, [])
81
-
82
- assert os.path.exists(csv_path)
83
  with open(csv_path, "r") as f:
84
- reader = csv.reader(f)
85
- header = next(reader)
86
- assert header == ["participant_id", "public_key_pem"]
87
- # Check that there are no more rows
88
- with pytest.raises(StopIteration):
89
- next(reader)
90
-
91
- def test_rebuild_csv_writes_correct_data(self, tmp_path):
92
- """Verify the CSV is created with the correct participant data."""
93
- key_dir = tmp_path / "csv_test"
94
  os.makedirs(key_dir)
95
- csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
96
 
97
  participants = [
98
- ("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
99
- ("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
100
  ]
101
- rebuild_authorized_keys_csv(key_dir, participants)
102
 
 
103
  with open(csv_path, "r") as f:
104
- reader = csv.reader(f)
105
- header = next(reader)
106
- row1 = next(reader)
107
- row2 = next(reader)
108
- assert header == ["participant_id", "public_key_pem"]
109
- assert row1 == list(participants[0])
110
- assert row2 == list(participants[1])
111
-
112
- def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
113
- """Verify that an existing CSV file is correctly overwritten."""
114
- key_dir = tmp_path / "csv_test"
115
  os.makedirs(key_dir)
 
 
 
 
 
 
 
 
 
 
 
116
  csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
 
 
117
 
118
- # First run with initial data
119
- initial_participants = [("old_p1", "old_key_1")]
120
- rebuild_authorized_keys_csv(key_dir, initial_participants)
 
121
 
122
- # Second run with new data
123
- new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
124
- rebuild_authorized_keys_csv(key_dir, new_participants)
 
 
 
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  with open(csv_path, "r") as f:
127
- reader = csv.reader(f)
128
- _ = next(reader)
129
- rows = list(reader)
130
 
131
- assert len(rows) == 2
132
- assert rows[0] == list(new_participants[0])
133
- assert rows[1] == list(new_participants[1])
 
1
  import os
 
2
  import stat
3
  import pytest
4
  from cryptography.hazmat.primitives import serialization
 
24
  AuthKeyGenerator(key_dir=str(key_dir))
25
  assert os.path.exists(key_dir)
26
 
27
+ def test_generate_participant_keys_creates_files_and_returns_openssh_with_comment(
28
  self, key_generator
29
  ):
30
  """
31
+ Verify the main method generates files and returns the public key
32
+ in the correct OpenSSH format including a comment.
33
  """
34
  participant_id = "participant_01"
35
+ priv_path, pub_path, pub_ssh_string = key_generator.generate_participant_keys(
36
  participant_id
37
  )
38
 
39
+ # 1. Check file existence and permissions
40
  assert os.path.exists(priv_path)
41
  assert os.path.exists(pub_path)
 
 
 
 
 
42
  if os.name != "nt":
43
  file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
44
  assert file_mode == 0o600
45
 
46
+ # 2. Verify that the returned string has three parts (type, key, comment)
47
+ assert pub_ssh_string.startswith("ecdsa-sha2-nistp384")
48
+ assert pub_ssh_string.endswith(participant_id)
49
+ assert len(pub_ssh_string.split(" ")) == 3
 
 
 
 
 
50
 
51
+ # 3. Verify that the public key file can be loaded as an SSH key
52
+ with open(pub_path, "rb") as f:
53
+ public_key_from_file = serialization.load_ssh_public_key(f.read())
54
+ assert isinstance(public_key_from_file, ec.EllipticCurvePublicKey)
 
 
 
55
 
56
 
57
+ class TestRebuildAuthorizedKeysFile:
58
  """Test suite for the rebuild_authorized_keys_csv function."""
59
 
60
+ def test_rebuild_creates_file_with_only_newline_for_empty_list(self, tmp_path):
61
+ """Verify an empty participant list results in a file with just a newline."""
62
+ key_dir = tmp_path / "keys_test"
63
  os.makedirs(key_dir)
64
+ rebuild_authorized_keys_csv(str(key_dir), [])
65
 
66
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
 
 
67
  with open(csv_path, "r") as f:
68
+ content = f.read()
69
+ assert content == "\n"
70
+
71
+ def test_rebuild_writes_correct_single_line_format(self, tmp_path):
72
+ """Verify the file is created in the single-line, comma-separated format."""
73
+ key_dir = tmp_path / "keys_test"
 
 
 
 
74
  os.makedirs(key_dir)
 
75
 
76
  participants = [
77
+ ("p1", "ecdsa-sha2-nistp384 AAAA...key1 p1"),
78
+ ("p2", "ecdsa-sha2-nistp384 AAAA...key2 p2"),
79
  ]
80
+ rebuild_authorized_keys_csv(str(key_dir), participants)
81
 
82
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
83
  with open(csv_path, "r") as f:
84
+ content = f.read().strip()
85
+
86
+ expected_content = (
87
+ "ecdsa-sha2-nistp384 AAAA...key1 p1,ecdsa-sha2-nistp384 AAAA...key2 p2"
88
+ )
89
+ assert content == expected_content
90
+
91
+ def test_rebuild_overwrites_existing_file(self, tmp_path):
92
+ """Verify that an existing file is correctly overwritten."""
93
+ key_dir = tmp_path / "keys_test"
 
94
  os.makedirs(key_dir)
95
+
96
+ # Use dummy data that matches the expected OpenSSH format
97
+ initial_participants = [("old_p1", "ecdsa-sha2-nistp384 old_key_1 old_p1")]
98
+ rebuild_authorized_keys_csv(str(key_dir), initial_participants)
99
+
100
+ new_participants = [
101
+ ("new_p1", "ecdsa-sha2-nistp384 new_key_1 new_p1"),
102
+ ("new_p2", "ecdsa-sha2-nistp384 new_key_2 new_p2"),
103
+ ]
104
+ rebuild_authorized_keys_csv(str(key_dir), new_participants)
105
+
106
  csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
107
+ with open(csv_path, "r") as f:
108
+ content = f.read().strip()
109
 
110
+ expected_content = (
111
+ "ecdsa-sha2-nistp384 new_key_1 new_p1,ecdsa-sha2-nistp384 new_key_2 new_p2"
112
+ )
113
+ assert content == expected_content
114
 
115
+ def test_rebuild_sanitizes_pem_keys_to_ssh_format(self, tmp_path):
116
+ """
117
+ Tests the self-healing capability of the rebuild function to convert
118
+ old PEM keys from the database into the correct OpenSSH format.
119
+ """
120
+ key_dir = tmp_path / "keys_test"
121
+ os.makedirs(key_dir)
122
 
123
+ # Generate a real key pair to get a valid PEM string
124
+ private_key = ec.generate_private_key(ec.SECP384R1())
125
+ public_key = private_key.public_key()
126
+ pem_key = public_key.public_bytes(
127
+ encoding=serialization.Encoding.PEM,
128
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
129
+ ).decode("utf-8")
130
+
131
+ participants = [("p1_pem", pem_key)]
132
+ rebuild_authorized_keys_csv(str(key_dir), participants)
133
+
134
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
135
  with open(csv_path, "r") as f:
136
+ content = f.read().strip()
 
 
137
 
138
+ # Verify the output is now in the correct OpenSSH format with the comment
139
+ assert content.startswith("ecdsa-sha2-nistp384")
140
+ assert content.endswith("p1_pem")
tests/test_blossomfile.py CHANGED
@@ -50,7 +50,7 @@ def test_create_blossomfile_success(tmp_path, dummy_credential_files):
50
 
51
  # 1. Verify the file was created at the correct path
52
  assert os.path.exists(blossomfile_path)
53
- assert blossomfile_path == str(output_dir / f"{participant_id}.blossomfile")
54
 
55
  # 2. Verify the contents of the zip archive
56
  with zipfile.ZipFile(blossomfile_path, "r") as zf:
 
50
 
51
  # 1. Verify the file was created at the correct path
52
  assert os.path.exists(blossomfile_path)
53
+ assert blossomfile_path == str(output_dir / f"{participant_id}-blossomfile.zip")
54
 
55
  # 2. Verify the contents of the zip archive
56
  with zipfile.ZipFile(blossomfile_path, "r") as zf:
tests/test_federation.py CHANGED
@@ -34,7 +34,7 @@ class TestCheckParticipantStatus:
34
  """Verify successful registration for a new user."""
35
  mock_mail.return_value = (True, "")
36
  approved, message, download = fed.check_participant_status(
37
- "new_user", "new@example.com", ""
38
  )
39
  assert approved is False
40
  assert download is None
@@ -43,7 +43,7 @@ class TestCheckParticipantStatus:
43
  # Verify the user was added to the database
44
  request = db_session.query(Request).filter_by(hf_handle="new_user").first()
45
  assert request is not None
46
- assert request.email == "new@example.com"
47
 
48
  def test_new_user_invalid_email(self, db_session, mock_settings):
49
  """Verify registration fails with an invalid email."""
@@ -108,7 +108,7 @@ class TestCheckParticipantStatus:
108
  assert download is None
109
  assert message == "mock_activation_invalid_md"
110
 
111
- def test_status_check_approved(self, db_session, mock_settings):
112
  """Verify the status check for an approved user."""
113
  approved_user = Request(
114
  participant_id="PID456",
@@ -125,9 +125,9 @@ class TestCheckParticipantStatus:
125
  approved, message, download = fed.check_participant_status(
126
  "approved_user", "approved@example.com", "GHIJKL"
127
  )
128
- assert approved is True
129
- assert download is not None
130
- assert "mock_status_approved_md" in message
131
 
132
 
133
  class TestManageRequest:
@@ -147,7 +147,10 @@ class TestManageRequest:
147
 
148
  success, message = fed.manage_request("PENDING1", "10", "approve")
149
  assert success is True
150
- assert "is allowed to join" in message
 
 
 
151
 
152
  # Verify status in DB
153
  updated_user = (
@@ -182,7 +185,7 @@ class TestManageRequest:
182
  db_session.commit()
183
  success, message = fed.manage_request("PENDING3", "", "deny")
184
  assert success is True
185
- assert "is not allowed to join" in message
186
 
187
  # Verify status in DB
188
  updated_user = (
 
34
  """Verify successful registration for a new user."""
35
  mock_mail.return_value = (True, "")
36
  approved, message, download = fed.check_participant_status(
37
+ "new_user", "hello@ethicalabs.ai", ""
38
  )
39
  assert approved is False
40
  assert download is None
 
43
  # Verify the user was added to the database
44
  request = db_session.query(Request).filter_by(hf_handle="new_user").first()
45
  assert request is not None
46
+ assert request.email == "hello@ethicalabs.ai"
47
 
48
  def test_new_user_invalid_email(self, db_session, mock_settings):
49
  """Verify registration fails with an invalid email."""
 
108
  assert download is None
109
  assert message == "mock_activation_invalid_md"
110
 
111
+ def test_status_check_approved_unmanaged(self, db_session, mock_settings):
112
  """Verify the status check for an approved user."""
113
  approved_user = Request(
114
  participant_id="PID456",
 
125
  approved, message, download = fed.check_participant_status(
126
  "approved_user", "approved@example.com", "GHIJKL"
127
  )
128
+ assert approved is False
129
+ assert download is None # fed.manage_request is not used. download is None.
130
+ assert "An error occurred" in message
131
 
132
 
133
  class TestManageRequest:
 
147
 
148
  success, message = fed.manage_request("PENDING1", "10", "approve")
149
  assert success is True
150
+ assert (
151
+ "Participant PENDING1 approved. Keys generated and registry updated."
152
+ in message
153
+ )
154
 
155
  # Verify status in DB
156
  updated_user = (
 
185
  db_session.commit()
186
  success, message = fed.manage_request("PENDING3", "", "deny")
187
  assert success is True
188
+ assert "Participant PENDING3 denied. Their access has been revoked." in message
189
 
190
  # Verify status in DB
191
  updated_user = (
uv.lock CHANGED
The diff for this file is too large to render. See raw diff