Add Elliptic Curve (EC) Authentication (#21)
* Add a single public_key_pem column to the database.
* Add --auth-list-public-keys to docker entrypoint's superlink cmd
* Add --auth-list-public-keys to superlink subprocess cmd
* Add blossomfile generation
* Handle custom CSV file format support
* Convert PEM-formatted key to SSH format (see the standalone sketch below the file list)
* Fix tests and dependencies
* Fix test_rebuild_overwrites_existing_file
* Add requirements.txt for custom HF space install
- alembic/versions/e7169fe29ea1_add_public_key_pem_to_request_table.py +33 -0
- blossomtune_gradio/auth_keys.py +65 -49
- blossomtune_gradio/blossomfile.py +1 -1
- blossomtune_gradio/config.py +15 -1
- blossomtune_gradio/database.py +1 -0
- blossomtune_gradio/federation.py +99 -53
- blossomtune_gradio/processing.py +3 -1
- blossomtune_gradio/settings/blossomtune.schema.json +1 -1
- blossomtune_gradio/settings/blossomtune.yaml +1 -1
- docker_entrypoint.sh +2 -2
- pyproject.toml +3 -2
- requirements.txt +16 -0
- tests/test_auth_keys.py +79 -72
- tests/test_blossomfile.py +1 -1
- tests/test_federation.py +11 -8
- uv.lock +0 -0
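
For the PEM-to-OpenSSH conversion noted above, the cryptography package does the heavy lifting; a standalone sketch follows (this is not the project's _sanitize_key helper, which additionally appends the participant ID as a comment):

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

# Throwaway P-384 key pair, just to obtain a PEM-encoded public key.
pem = (
    ec.generate_private_key(ec.SECP384R1())
    .public_key()
    .public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
)

# Re-encode the same public key in the OpenSSH format the SuperLink expects.
public_key = serialization.load_pem_public_key(pem)
openssh = public_key.public_bytes(
    encoding=serialization.Encoding.OpenSSH,
    format=serialization.PublicFormat.OpenSSH,
).decode("utf-8")
print(openssh)  # "ecdsa-sha2-nistp384 AAAA..."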
alembic/versions/e7169fe29ea1_add_public_key_pem_to_request_table.py
ADDED
@@ -0,0 +1,33 @@
"""Add public_key_pem to Request table.

Revision ID: e7169fe29ea1
Revises: 0b01b44ab005
Create Date: 2025-10-09 13:23:30.141015

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e7169fe29ea1"
down_revision: Union[str, Sequence[str], None] = "0b01b44ab005"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("requests", sa.Column("public_key_pem", sa.String(), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("requests", "public_key_pem")
    # ### end Alembic commands ###
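The migration can be applied through Alembic's programmatic API as well as the usual `alembic upgrade head` CLI call; a minimal sketch, assuming an alembic.ini at the repository root pointing at the project's database:

from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")    # example path, adjust to the repo layout
command.upgrade(alembic_cfg, "head")   # adds the public_key_pem column
# command.downgrade(alembic_cfg, "-1") # drops it again if needed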
blossomtune_gradio/auth_keys.py
CHANGED
@@ -1,39 +1,69 @@
import os
import logging
from typing import List, Tuple
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization

# Configure logging for the module
log = logging.getLogger(__name__)


def _sanitize_key(participant_id: str, key_str: str) -> str | None:
    """
    Inspects a key string and converts it to the required OpenSSH format if necessary.
    This provides resilience against old, PEM-formatted keys in the database.
    """
    if not key_str:
        return None
    # If the key is already in the correct OpenSSH format, return it as is.
    if key_str.startswith("ecdsa-sha2-nistp384"):
        return key_str
    # If the key is in the old PEM format, attempt to convert it.
    if "-----BEGIN PUBLIC KEY-----" in key_str:
        log.warning(
            f"Found PEM-formatted key for participant {participant_id}. Converting to OpenSSH."
        )
        try:
            public_key = serialization.load_pem_public_key(key_str.encode("utf-8"))
            key_body = public_key.public_bytes(
                encoding=serialization.Encoding.OpenSSH,
                format=serialization.PublicFormat.OpenSSH,
            ).decode("utf-8")
            # Re-add the participant_id as a comment to conform to the 3-part format.
            return f"{key_body} {participant_id}"
        except Exception as e:
            log.error(f"Could not convert PEM key for {participant_id}: {e}")
            return None
    # If the key format is unknown, log an error and skip it.
    log.error(f"Unknown public key format for participant {participant_id}. Skipping.")
    return None


def rebuild_authorized_keys_csv(
    key_dir: str, authorized_participants: List[Tuple[str, str]]
):
    """
    Overwrites the public key file with a fresh list from a trusted source,
    using the specific single-line, comma-separated format expected by Flower.
    """
    csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
    log.info(f"Rebuilding authorized keys file at: {csv_path}")

    # Sanitize each key before adding it to the list.
    public_keys = [
        sanitized_key
        for p_id, key_string in authorized_participants
        if (sanitized_key := _sanitize_key(p_id, key_string)) is not None
    ]

    # Join all valid public keys into a single comma-separated string.
    content = ",".join(public_keys)

    # Write the single line to the file, followed by a newline.
    with open(csv_path, "w") as f:
        f.write(content + "\n")

    log.info(f"Successfully rebuilt {csv_path} with {len(public_keys)} keys.")


class AuthKeyGenerator:

@@ -43,12 +73,6 @@ class AuthKeyGenerator:
    """

    def __init__(self, key_dir: str = "keys"):
        self.key_dir = key_dir
        os.makedirs(self.key_dir, exist_ok=True)
        log.info(f"Authentication key directory set to: {self.key_dir}")

@@ -70,24 +94,17 @@ class AuthKeyGenerator:
                    encryption_algorithm=serialization.NoEncryption(),
                )
            )
        os.chmod(priv_key_path, 0o600)
        log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
        return priv_key_path

    def _save_public_key_file(
        self, public_key_ssh_string: str, participant_id: str
    ) -> str:
        """Saves the full OpenSSH public key string to a .pub file."""
        pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
        with open(pub_key_path, "w") as f:
            f.write(public_key_ssh_string)
        log.info(f"Public key for {participant_id} saved to {pub_key_path}")
        return pub_key_path

@@ -95,26 +112,25 @@ class AuthKeyGenerator:
        """
        Generates and saves a new EC key pair for a participant.

        Returns:
            A tuple containing:
            - The file path to the generated private key.
            - The file path to the generated public key.
            - The public key as a single-line OpenSSH string with a comment.
        """
        private_key = self._generate_key_pair()
        public_key = private_key.public_key()

        # Generate the base OpenSSH key string (type and key data)
        key_body = public_key.public_bytes(
            encoding=serialization.Encoding.OpenSSH,
            format=serialization.PublicFormat.OpenSSH,
        ).decode("utf-8")

        # Append the participant_id as a comment, creating the full 3-part key.
        public_key_ssh_string = f"{key_body} {participant_id}"

        priv_key_path = self._save_private_key(private_key, participant_id)
        pub_key_path = self._save_public_key_file(public_key_ssh_string, participant_id)

        return (priv_key_path, pub_key_path, public_key_ssh_string)
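For reference, the approval path in federation.py below reduces to the following usage pattern; a minimal local sketch with example paths and a single hard-coded participant instead of the real database query:

from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv

# Example local run: keys land in ./keys next to authorized_supernodes.csv.
generator = AuthKeyGenerator(key_dir="keys")
priv_path, pub_path, pub_ssh = generator.generate_participant_keys("participant_01")

# The SuperLink reads one comma-separated line of OpenSSH keys from this file.
rebuild_authorized_keys_csv("keys", [("participant_01", pub_ssh)])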
blossomtune_gradio/blossomfile.py
CHANGED
@@ -39,7 +39,7 @@ def create_blossomfile(
        The full path to the generated .blossomfile.
    """
    os.makedirs(output_dir, exist_ok=True)
    blossomfile_path = os.path.join(output_dir, f"{participant_id}-blossomfile.zip")
    log.info(f"Creating Blossomfile for {participant_id} at {blossomfile_path}")

    # 1. Create the blossom.json configuration data
blossomtune_gradio/config.py
CHANGED
@@ -22,7 +22,7 @@ SMTP_REQUIRE_TLS = util.strtobool(os.getenv("SMTP_REQUIRE_TLS", "false"))
SMTP_USER = os.getenv("SMTP_USER", "")
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1")
SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
SUPERLINK_CONTROL_API_PORT = int(os.getenv("SUPERLINK_CONTROL_API_PORT", 9093))
SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower()  # Or external

@@ -50,6 +50,20 @@ BLOSSOMTUNE_TLS_CA_CERTFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "ca.crt")
BLOSSOMTUNE_TLS_CERTFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "server.pem")
BLOSSOMTUNE_TLS_KEYFILE = os.path.join(BLOSSOMTUNE_TLS_CERT_PATH, "server.key")

# EC Auth - Keys
AUTH_KEYS_DIR = os.getenv(
    "AUTH_KEYS_DIR",
    "/data/keys/"
    if os.path.isdir("/data/keys")
    else os.path.join(PROJECT_PATH, "./data/keys/"),
)

# EC Auth - CSV File
AUTH_KEYS_CSV_PATH = os.getenv(
    "AUTH_KEYS_CSV_PATH",
    os.path.join(AUTH_KEYS_DIR, "authorized_supernodes.csv"),
)

# Flower Apps
FLOWER_APPS = os.getenv("FLOWER_APPS", ["flower_apps.quickstart_huggingface"])
FLOWER_APPS = (
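Both values are read from the environment at import time, so a deployment can point them elsewhere; a minimal sketch with example paths, set before blossomtune_gradio.config is first imported:

import os

# Example paths only; any writable location works.
os.environ["AUTH_KEYS_DIR"] = "/tmp/blossom-keys"
os.environ["AUTH_KEYS_CSV_PATH"] = "/tmp/blossom-keys/authorized_supernodes.csv"

from blossomtune_gradio import config as cfg

print(cfg.AUTH_KEYS_DIR, cfg.AUTH_KEYS_CSV_PATH)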
blossomtune_gradio/database.py
CHANGED
@@ -29,6 +29,7 @@ class Request(Base):
    hf_handle = Column(String, nullable=True)
    activation_code = Column(String, nullable=True)
    is_activated = Column(Integer, nullable=False, default=0)
    public_key_pem = Column(String(), nullable=True)

    def __repr__(self):
        return (
blossomtune_gradio/federation.py
CHANGED
@@ -1,11 +1,15 @@
import os
import string
import secrets
import tempfile

from blossomtune_gradio import config as cfg
from blossomtune_gradio import mail
from blossomtune_gradio import util
from blossomtune_gradio.settings import settings
from blossomtune_gradio.database import SessionLocal, Request, Config
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
from blossomtune_gradio.blossomfile import create_blossomfile


def generate_participant_id(length=6):

@@ -24,7 +28,6 @@ def check_participant_status(pid_to_check: str, email: str, activation_code: str
    """
    Handles a participant's request to join, activate, or check status using SQLAlchemy.
    Returns a tuple: (is_approved: bool, message: str, data: any | None)
    """
    with SessionLocal() as db:
        query = db.query(Request).filter(

@@ -40,51 +43,46 @@ def check_participant_status(pid_to_check: str, email: str, activation_code: str
        )
        num_partitions = num_partitions_config.value if num_partitions_config else "10"

        # Case 1 & 2 are for users not yet approved
        if not request or not request.is_activated or not activation_code:
            if request is None:
                if activation_code:
                    return (False, settings.get_text("activation_invalid_md"), None)
                if not util.validate_email(email):
                    return (False, settings.get_text("invalid_email_md"), None)
                approved_count = (
                    db.query(Request).filter(Request.status == "approved").count()
                )
                if approved_count >= cfg.MAX_NUM_NODES:
                    return (False, settings.get_text("federation_full_md"), None)
                participant_id = generate_participant_id()
                new_activation_code = generate_activation_code()
                mail_sent, message = mail.send_activation_email(
                    email, new_activation_code
                )
                if mail_sent:
                    new_request = Request(
                        participant_id=participant_id,
                        hf_handle=pid_to_check,
                        email=email,
                        activation_code=new_activation_code,
                    )
                    db.add(new_request)
                    db.commit()
                    return (False, settings.get_text("registration_submitted_md"), None)
                else:
                    return (False, message, None)

            if not request.is_activated:
                if activation_code == request.activation_code:
                    request.is_activated = 1
                    db.commit()
                    return (False, settings.get_text("activation_successful_md"), None)
                else:
                    return (False, settings.get_text("activation_invalid_md"), None)

            if not activation_code:
                return (False, settings.get_text("missing_activation_code_md"), None)

        # Case 3: Activated user is checking their final status
        if request.status == "approved":

@@ -93,17 +91,37 @@ def check_participant_status(pid_to_check: str, email: str, activation_code: str
                if not cfg.SPACE_ID
                else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
            )
            superlink_address = f"{cfg.SUPERLINK_HOST or hostname}:{cfg.SUPERLINK_PORT}"

            # Blossomfile Generation
            blossomfile_tempdir = tempfile.mkdtemp()  # TODO: remove tempdirs
            try:
                blossomfile_path = create_blossomfile(
                    participant_id=request.participant_id,
                    output_dir=blossomfile_tempdir,
                    ca_cert_path=cfg.BLOSSOMTUNE_TLS_CA_CERTFILE,
                    auth_key_path=os.path.join(
                        cfg.AUTH_KEYS_DIR, f"{request.participant_id}.key"
                    ),
                    auth_pub_path=os.path.join(
                        cfg.AUTH_KEYS_DIR, f"{request.participant_id}.pub"
                    ),
                    superlink_address=superlink_address,
                    partition_id=request.partition_id,
                    num_partitions=int(num_partitions),
                )
            except FileNotFoundError:
                return (False, "An error occurred.", None)

            connection_string = settings.get_text(
                "status_approved_md",
                participant_id=request.participant_id,
                partition_id=request.partition_id,
                superlink_hostname=superlink_address.split(":")[0],
                superlink_port=superlink_address.split(":")[1],
                num_partitions=num_partitions,
            )
            return (True, connection_string, blossomfile_path)
        elif request.status == "pending":
            return (False, settings.get_text("status_pending_md"), None)
        else:  # Denied

@@ -131,7 +149,6 @@ def manage_request(participant_id: str, partition_id: str, action: str):
        if action == "approve":
            if not partition_id or not partition_id.isdigit():
                return False, "Please provide a valid integer for the Partition ID."
            p_id_int = int(partition_id)
            if not request.is_activated:
                return (

@@ -144,7 +161,6 @@ def manage_request(participant_id: str, partition_id: str, action: str):
                .filter(Request.partition_id == p_id_int, Request.status == "approved")
                .first()
            )
            if existing_participant:
                return (
                    False,

@@ -155,18 +171,48 @@ def manage_request(participant_id: str, partition_id: str, action: str):
            request.status = "approved"
            request.partition_id = p_id_int

            # Generate and Store Auth Keys
            key_generator = AuthKeyGenerator(key_dir=cfg.AUTH_KEYS_DIR)
            _, _, public_key_pem = key_generator.generate_participant_keys(
                participant_id
            )
            request.public_key_pem = public_key_pem
            db.commit()

            # Rebuild Authorized Keys CSV
            approved_participants = (
                db.query(Request.participant_id, Request.public_key_pem)
                .filter(
                    Request.status == "approved", Request.public_key_pem.isnot(None)
                )
                .all()
            )
            rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)

            return (
                True,
                f"Participant {participant_id} approved. Keys generated and registry updated.",
            )
        else:  # Deny
            request.status = "denied"
            request.partition_id = None
            request.public_key_pem = None
            db.commit()

            # --- Rebuild CSV after denial to revoke access ---
            approved_participants = (
                db.query(Request.participant_id, Request.public_key_pem)
                .filter(
                    Request.status == "approved", Request.public_key_pem.isnot(None)
                )
                .all()
            )
            rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)

            return (
                True,
                f"Participant {participant_id} denied. Their access has been revoked.",
            )
blossomtune_gradio/processing.py
CHANGED
@@ -54,7 +54,9 @@ def start_superlink():
        cfg.BLOSSOMTUNE_TLS_CERTFILE,
        "--ssl-keyfile",
        cfg.BLOSSOMTUNE_TLS_KEYFILE,
        "--auth-list-public-keys",
        cfg.AUTH_KEYS_CSV_PATH,
    ]
    threading.Thread(
        target=run_process, args=(command, "superlink"), daemon=True
    ).start()
blossomtune_gradio/settings/blossomtune.schema.json
CHANGED
@@ -100,7 +100,7 @@
        "status_pending_md",
        "status_denied_md",
        "participant_not_activated_warning_md",
        "partition_in_use_warning_md",
        "auth_status_logged_in_owner_md",
        "auth_status_logged_in_user_md",
        "auth_status_not_logged_in_md",
blossomtune_gradio/settings/blossomtune.yaml
CHANGED
@@ -66,7 +66,7 @@ ui:
  participant_not_activated_warning_md: |
    This participant has not activated their email yet. Approval is not allowed.

  partition_in_use_warning_md: |
    Partition ID {{ partition_id }} is already assigned. Please choose a different one.

  # --- Authentication Status ---
docker_entrypoint.sh
CHANGED
@@ -1,5 +1,4 @@
#!/bin/bash
set -e

echo "Env vars:"

@@ -13,7 +12,8 @@ elif [ "${1}" = "superlink" ]; then
    exec flower-superlink \
        --ssl-ca-certfile /data/certs/ca.crt \
        --ssl-certfile /data/certs/server.pem \
        --ssl-keyfile /data/certs/server.key \
        --auth-list-public-keys /data/keys/authorized_supernodes.csv
else
    exec "$@"
fi
pyproject.toml
CHANGED
@@ -18,6 +18,9 @@ dependencies = [
    "mlx[cpu]>=0.29.2",
    "alembic>=1.16.5",
    "sqlalchemy>=2.0.43",
    "cryptography>=44.0.3",
    "dnspython>=2.8.0",
    "mlx-lm>=0.28.2",
]

[tool.uv.sources]

@@ -71,8 +74,6 @@ convention = "google" # Accepts: "google", "numpy", or "pep257".

[dependency-groups]
dev = [
    "pytest>=8.4.1",
    "pytest-mock>=3.15.1",
]
requirements.txt
ADDED
@@ -0,0 +1,16 @@
apscheduler>=3.11.0
flwr[simulation]>=1.21.0
flwr-datasets>=0.5.0
gradio[oauth]>=5.44.1
torch>=2.8.0
transformers>=4.56.1
scikit-learn>=1.7.1
evaluate>=0.4.5
markupsafe==2.1.3
jinja2>=3.1.6
# mlx[cpu]>=0.29.2
alembic>=1.16.5
sqlalchemy>=2.0.43
cryptography>=44.0.3
dnspython>=2.8.0
# mlx-lm>=0.28.2
tests/test_auth_keys.py
CHANGED
@@ -1,5 +1,4 @@
import os
import stat
import pytest
from cryptography.hazmat.primitives import serialization

@@ -25,109 +24,117 @@ class TestAuthKeyGenerator:
        AuthKeyGenerator(key_dir=str(key_dir))
        assert os.path.exists(key_dir)

    def test_generate_participant_keys_creates_files_and_returns_openssh_with_comment(
        self, key_generator
    ):
        """
        Verify the main method generates files and returns the public key
        in the correct OpenSSH format including a comment.
        """
        participant_id = "participant_01"
        priv_path, pub_path, pub_ssh_string = key_generator.generate_participant_keys(
            participant_id
        )

        # 1. Check file existence and permissions
        assert os.path.exists(priv_path)
        assert os.path.exists(pub_path)
        if os.name != "nt":
            file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
            assert file_mode == 0o600

        # 2. Verify that the returned string has three parts (type, key, comment)
        assert pub_ssh_string.startswith("ecdsa-sha2-nistp384")
        assert pub_ssh_string.endswith(participant_id)
        assert len(pub_ssh_string.split(" ")) == 3

        # 3. Verify that the public key file can be loaded as an SSH key
        with open(pub_path, "rb") as f:
            public_key_from_file = serialization.load_ssh_public_key(f.read())
        assert isinstance(public_key_from_file, ec.EllipticCurvePublicKey)


class TestRebuildAuthorizedKeysFile:
    """Test suite for the rebuild_authorized_keys_csv function."""

    def test_rebuild_creates_file_with_only_newline_for_empty_list(self, tmp_path):
        """Verify an empty participant list results in a file with just a newline."""
        key_dir = tmp_path / "keys_test"
        os.makedirs(key_dir)
        rebuild_authorized_keys_csv(str(key_dir), [])

        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
        with open(csv_path, "r") as f:
            content = f.read()
        assert content == "\n"

    def test_rebuild_writes_correct_single_line_format(self, tmp_path):
        """Verify the file is created in the single-line, comma-separated format."""
        key_dir = tmp_path / "keys_test"
        os.makedirs(key_dir)

        participants = [
            ("p1", "ecdsa-sha2-nistp384 AAAA...key1 p1"),
            ("p2", "ecdsa-sha2-nistp384 AAAA...key2 p2"),
        ]
        rebuild_authorized_keys_csv(str(key_dir), participants)

        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
        with open(csv_path, "r") as f:
            content = f.read().strip()

        expected_content = (
            "ecdsa-sha2-nistp384 AAAA...key1 p1,ecdsa-sha2-nistp384 AAAA...key2 p2"
        )
        assert content == expected_content

    def test_rebuild_overwrites_existing_file(self, tmp_path):
        """Verify that an existing file is correctly overwritten."""
        key_dir = tmp_path / "keys_test"
        os.makedirs(key_dir)

        # Use dummy data that matches the expected OpenSSH format
        initial_participants = [("old_p1", "ecdsa-sha2-nistp384 old_key_1 old_p1")]
        rebuild_authorized_keys_csv(str(key_dir), initial_participants)

        new_participants = [
            ("new_p1", "ecdsa-sha2-nistp384 new_key_1 new_p1"),
            ("new_p2", "ecdsa-sha2-nistp384 new_key_2 new_p2"),
        ]
        rebuild_authorized_keys_csv(str(key_dir), new_participants)

        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
        with open(csv_path, "r") as f:
            content = f.read().strip()

        expected_content = (
            "ecdsa-sha2-nistp384 new_key_1 new_p1,ecdsa-sha2-nistp384 new_key_2 new_p2"
        )
        assert content == expected_content

    def test_rebuild_sanitizes_pem_keys_to_ssh_format(self, tmp_path):
        """
        Tests the self-healing capability of the rebuild function to convert
        old PEM keys from the database into the correct OpenSSH format.
        """
        key_dir = tmp_path / "keys_test"
        os.makedirs(key_dir)

        # Generate a real key pair to get a valid PEM string
        private_key = ec.generate_private_key(ec.SECP384R1())
        public_key = private_key.public_key()
        pem_key = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode("utf-8")

        participants = [("p1_pem", pem_key)]
        rebuild_authorized_keys_csv(str(key_dir), participants)

        csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
        with open(csv_path, "r") as f:
            content = f.read().strip()

        # Verify the output is now in the correct OpenSSH format with the comment
        assert content.startswith("ecdsa-sha2-nistp384")
        assert content.endswith("p1_pem")
tests/test_blossomfile.py
CHANGED
@@ -50,7 +50,7 @@ def test_create_blossomfile_success(tmp_path, dummy_credential_files):

    # 1. Verify the file was created at the correct path
    assert os.path.exists(blossomfile_path)
    assert blossomfile_path == str(output_dir / f"{participant_id}-blossomfile.zip")

    # 2. Verify the contents of the zip archive
    with zipfile.ZipFile(blossomfile_path, "r") as zf:
tests/test_federation.py
CHANGED
@@ -34,7 +34,7 @@ class TestCheckParticipantStatus:
        """Verify successful registration for a new user."""
        mock_mail.return_value = (True, "")
        approved, message, download = fed.check_participant_status(
            "new_user", "hello@ethicalabs.ai", ""
        )
        assert approved is False
        assert download is None

@@ -43,7 +43,7 @@ class TestCheckParticipantStatus:
        # Verify the user was added to the database
        request = db_session.query(Request).filter_by(hf_handle="new_user").first()
        assert request is not None
        assert request.email == "hello@ethicalabs.ai"

    def test_new_user_invalid_email(self, db_session, mock_settings):
        """Verify registration fails with an invalid email."""

@@ -108,7 +108,7 @@ class TestCheckParticipantStatus:
        assert download is None
        assert message == "mock_activation_invalid_md"

    def test_status_check_approved_unmanaged(self, db_session, mock_settings):
        """Verify the status check for an approved user."""
        approved_user = Request(
            participant_id="PID456",

@@ -125,9 +125,9 @@ class TestCheckParticipantStatus:
        approved, message, download = fed.check_participant_status(
            "approved_user", "approved@example.com", "GHIJKL"
        )
        assert approved is False
        assert download is None  # fed.manage_request is not used. download is None.
        assert "An error occurred" in message


class TestManageRequest:

@@ -147,7 +147,10 @@ class TestManageRequest:

        success, message = fed.manage_request("PENDING1", "10", "approve")
        assert success is True
        assert (
            "Participant PENDING1 approved. Keys generated and registry updated."
            in message
        )

        # Verify status in DB
        updated_user = (

@@ -182,7 +185,7 @@ class TestManageRequest:
        db_session.commit()
        success, message = fed.manage_request("PENDING3", "", "deny")
        assert success is True
        assert "Participant PENDING3 denied. Their access has been revoked." in message

        # Verify status in DB
        updated_user = (
uv.lock
CHANGED
The diff for this file is too large to render. See raw diff.