mrs83 commited on
Commit
3a1c55b
·
unverified ·
1 Parent(s): c3cb0fd

TLS certs and external superlink support (plus initial work on authentication keys and .blossomfile) (#3)

Browse files

* Adds script for TLS certs generation.
* Adds support for SUPERLINK_MODE (internal or external)
* Updates UI to support external superlink.
* Adds docker-compose.yaml file for self-signed local certs testing.
* Adds more tests
* Updates README.md
* Adds .blossomfile spec: a single file for the participant, with the configuration needed to connect to a federation. Work in progress!

README.md CHANGED
@@ -77,6 +77,50 @@ python -m blossomtune_gradio
77
 
78
  The application will be accessible via a local URL provided by Gradio.
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  ## Usage Guide
81
 
82
  ### For Participants
 
77
 
78
  The application will be accessible via a local URL provided by Gradio.
79
 
80
+ ## Generating Self-Signed Certificates for Local Development (Docker)
81
+
82
+ When running the application with `docker-compose`, the `superlink` service requires TLS certificates to enable secure connections.
83
+
84
+ For local development, you can generate a self-signed Certificate Authority (CA) and a `localhost` certificate using the provided script.
85
+
86
+ **Step 1: Run the Certificate Generator**
87
+
88
+ Execute the interactive TLS generation script located in the `blossomtune_gradio` directory:
89
+
90
+ ```bash
91
+ python3 -m blossomtune_gradio.generate_tls
92
+ ```
93
+
94
+ **Step 2: Choose the Development Option**
95
+
96
+ When prompted, select option **1** to generate a self-signed certificate for `localhost`.
97
+
98
+ ```text
99
+ ===== BlossomTune TLS Certificate Generator =====
100
+ Select an option:
101
+ 1. Generate a self-signed 'localhost' certificate (for Development)
102
+ 2. Generate a server certificate using the main CA (for Production)
103
+ 3. Exit
104
+ Enter your choice [1]: 1
105
+ ```
106
+
107
+ The script will create a new directory named `certificates_localhost` containing the generated CA (`ca.crt`) and the server certificate files (`server.key`, `server.crt`, `server.pem`).
108
+
109
+ **Step 3: Copy Certificates to the Data Directory**
110
+
111
+ The `docker-compose.yml` file is configured to mount a local `./data/certs` directory into the `superlink` container. You must copy the essential certificate files into this location:
112
+
113
+ ```bash
114
+ cp certificates_localhost/ca.crt ./data/certs/
115
+ cp certificates_localhost/server.key ./data/certs/
116
+ cp certificates_localhost/server.pem ./data/certs/
117
+ ```
118
+
119
+ Once these files are in place, you can start the services using `docker compose up`.
120
+
121
+ The Superlink will automatically find and use these certificates to secure its connections.
122
+
123
+
124
  ## Usage Guide
125
 
126
  ### For Participants
blossomtune_gradio/auth_keys.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import logging
4
+ from typing import List, Tuple
5
+ from cryptography.hazmat.primitives.asymmetric import ec
6
+ from cryptography.hazmat.primitives import serialization
7
+
8
+ log = logging.getLogger(__name__)
9
+
10
+
11
+ def rebuild_authorized_keys_csv(
12
+ key_dir: str, authorized_participants: List[Tuple[str, str]]
13
+ ):
14
+ """
15
+ Overwrites the public key CSV with a fresh list from a trusted source.
16
+
17
+ This function should be called before starting the Flower SuperLink to ensure
18
+ the list of authorized nodes is always in sync with the database.
19
+
20
+ Args:
21
+ key_dir (str): The directory where the CSV will be stored.
22
+ authorized_participants (List[Tuple[str, str]]): A list of tuples,
23
+ where each tuple contains (participant_id, public_key_pem).
24
+ """
25
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
26
+ log.info(f"Rebuilding authorized keys CSV at: {csv_path}")
27
+
28
+ with open(csv_path, "w", newline="") as f:
29
+ writer = csv.writer(f)
30
+ writer.writerow(["participant_id", "public_key_pem"])
31
+ for participant_id, public_key_pem in authorized_participants:
32
+ writer.writerow([participant_id, public_key_pem])
33
+
34
+ log.info(
35
+ f"Successfully rebuilt {csv_path} with {len(authorized_participants)} keys."
36
+ )
37
+
38
+
39
+ class AuthKeyGenerator:
40
+ """
41
+ Handles the generation of Elliptic Curve (EC) key pairs for Flower
42
+ SuperNode authentication.
43
+ """
44
+
45
+ def __init__(self, key_dir: str = "keys"):
46
+ """
47
+ Initializes the generator and ensures the key directory exists.
48
+
49
+ Args:
50
+ key_dir (str): The directory where key files will be stored.
51
+ """
52
+ self.key_dir = key_dir
53
+ os.makedirs(self.key_dir, exist_ok=True)
54
+ log.info(f"Authentication key directory set to: {self.key_dir}")
55
+
56
+ def _generate_key_pair(self) -> ec.EllipticCurvePrivateKey:
57
+ """Generates a single EC private key using the SECP384R1 curve."""
58
+ return ec.generate_private_key(ec.SECP384R1())
59
+
60
+ def _save_private_key(
61
+ self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
62
+ ) -> str:
63
+ """Saves the private key to a file with secure permissions."""
64
+ priv_key_path = os.path.join(self.key_dir, f"{participant_id}.key")
65
+ with open(priv_key_path, "wb") as f:
66
+ f.write(
67
+ private_key.private_bytes(
68
+ encoding=serialization.Encoding.PEM,
69
+ format=serialization.PrivateFormat.PKCS8,
70
+ encryption_algorithm=serialization.NoEncryption(),
71
+ )
72
+ )
73
+ # Set file permissions to read/write for owner only (600)
74
+ os.chmod(priv_key_path, 0o600)
75
+ log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
76
+ return priv_key_path
77
+
78
+ def _save_public_key_file(
79
+ self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
80
+ ) -> str:
81
+ """Saves the public key to a .pub file."""
82
+ public_key = private_key.public_key()
83
+ pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
84
+ with open(pub_key_path, "wb") as f:
85
+ f.write(
86
+ public_key.public_bytes(
87
+ encoding=serialization.Encoding.PEM,
88
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
89
+ )
90
+ )
91
+ log.info(f"Public key for {participant_id} saved to {pub_key_path}")
92
+ return pub_key_path
93
+
94
+ def generate_participant_keys(self, participant_id: str) -> Tuple[str, str, str]:
95
+ """
96
+ Generates and saves a new EC key pair for a participant.
97
+
98
+ This is the main public method to call when a new participant is approved.
99
+
100
+ Args:
101
+ participant_id (str): A unique identifier for the participant.
102
+
103
+ Returns:
104
+ A tuple containing:
105
+ - The file path to the generated private key.
106
+ - The file path to the generated public key.
107
+ - The public key as a PEM-encoded string (for database storage).
108
+ """
109
+ private_key = self._generate_key_pair()
110
+ public_key = private_key.public_key()
111
+
112
+ priv_key_path = self._save_private_key(private_key, participant_id)
113
+ pub_key_path = self._save_public_key_file(private_key, participant_id)
114
+
115
+ public_key_pem = public_key.public_bytes(
116
+ encoding=serialization.Encoding.PEM,
117
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
118
+ ).decode("utf-8")
119
+
120
+ return (priv_key_path, pub_key_path, public_key_pem)
blossomtune_gradio/blossomfile.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import zipfile
4
+ import logging
5
+ from typing import Dict, Any
6
+
7
+ # Configure logging for the module
8
+ log = logging.getLogger(__name__)
9
+
10
+
11
+ def create_blossomfile(
12
+ participant_id: str,
13
+ output_dir: str,
14
+ ca_cert_path: str,
15
+ auth_key_path: str,
16
+ auth_pub_path: str,
17
+ superlink_address: str,
18
+ partition_id: int,
19
+ num_partitions: int,
20
+ ) -> str:
21
+ """
22
+ Generates a .blossomfile for a participant.
23
+
24
+ This function packages all necessary credentials and configuration into a
25
+ single, portable .zip archive that a participant can use to easily
26
+ connect to the federation.
27
+
28
+ Args:
29
+ participant_id: The unique identifier for the participant.
30
+ output_dir: The directory where the final .blossomfile will be saved.
31
+ ca_cert_path: Path to the CA public certificate (`ca.crt`).
32
+ auth_key_path: Path to the participant's private EC key (`auth.key`).
33
+ auth_pub_path: Path to the participant's public EC key (`auth.pub`).
34
+ superlink_address: The public address of the Flower SuperLink.
35
+ partition_id: The data partition ID assigned to the participant.
36
+ num_partitions: The total number of partitions in the federation.
37
+
38
+ Returns:
39
+ The full path to the generated .blossomfile.
40
+ """
41
+ os.makedirs(output_dir, exist_ok=True)
42
+ blossomfile_path = os.path.join(output_dir, f"{participant_id}.blossomfile")
43
+ log.info(f"Creating Blossomfile for {participant_id} at {blossomfile_path}")
44
+
45
+ # 1. Create the blossom.json configuration data
46
+ blossom_config: Dict[str, Any] = {
47
+ "superlink_address": superlink_address,
48
+ "node_config": {
49
+ "partition-id": partition_id,
50
+ "num-partitions": num_partitions,
51
+ },
52
+ }
53
+
54
+ # 2. Define the files to be included in the archive
55
+ files_to_add = {
56
+ ca_cert_path: "ca.crt",
57
+ auth_key_path: "auth.key",
58
+ auth_pub_path: "auth.pub",
59
+ }
60
+
61
+ # 3. Create the zip archive
62
+ try:
63
+ with zipfile.ZipFile(blossomfile_path, "w", zipfile.ZIP_DEFLATED) as zf:
64
+ # Add the configuration file
65
+ zf.writestr("blossom.json", json.dumps(blossom_config, indent=2))
66
+ log.info("Added blossom.json to archive.")
67
+
68
+ # Add the certificate and key files
69
+ for src_path, arc_name in files_to_add.items():
70
+ if os.path.exists(src_path):
71
+ zf.write(src_path, arcname=arc_name)
72
+ log.info(f"Added {arc_name} to archive from {src_path}.")
73
+ else:
74
+ log.error(f"Credential file not found: {src_path}. Aborting.")
75
+ raise FileNotFoundError(
76
+ f"Required credential file not found: {src_path}"
77
+ )
78
+
79
+ except Exception as e:
80
+ log.critical(f"Failed to create Blossomfile: {e}")
81
+ # Clean up partially created file on failure
82
+ if os.path.exists(blossomfile_path):
83
+ os.remove(blossomfile_path)
84
+ raise
85
+
86
+ log.info(f"Successfully created Blossomfile: {blossomfile_path}")
87
+ return blossomfile_path
blossomtune_gradio/config.py CHANGED
@@ -10,7 +10,9 @@ SPACE_ID = os.getenv("SPACE_ID", "ethicalabs/BlossomTune-Orchestrator")
10
  SPACE_OWNER = os.getenv("SPACE_OWNER", SPACE_ID.split("/")[0] if SPACE_ID else None)
11
 
12
  # Use persistent storage if available
13
- DB_PATH = "/data/federation.db" if os.path.isdir("/data") else "federation.db"
 
 
14
  MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
15
  SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
16
  SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
@@ -19,5 +21,19 @@ SMTP_REQUIRE_TLS = util.strtobool(os.getenv("SMTP_REQUIRE_TLS", "false"))
19
  SMTP_USER = os.getenv("SMTP_USER", "")
20
  SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
21
  EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
22
- SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "")
23
  SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  SPACE_OWNER = os.getenv("SPACE_OWNER", SPACE_ID.split("/")[0] if SPACE_ID else None)
11
 
12
  # Use persistent storage if available
13
+ DB_PATH = (
14
+ "/data/db/federation.db" if os.path.isdir("/data/db") else "./data/db/federation.db"
15
+ )
16
  MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
17
  SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
18
  SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
 
21
  SMTP_USER = os.getenv("SMTP_USER", "")
22
  SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
23
  EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
24
+ SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1:9092")
25
  SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
26
+ SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
27
+
28
+ # TLS root cert path. For production only.
29
+ TLS_CERT_DIR = os.getenv("TLS_CERT_DIR", "./certs/")
30
+ TLS_CA_KEY_PATH = os.getenv("TLS_CA_KEY_PATH", False)
31
+ TLS_CA_CERT_PATH = os.getenv("TLS_CA_CERT_PATH", False)
32
+
33
+ # BlossomTune cert - To be distributed to the participants (supernodes).
34
+ BLOSSOMTUNE_TLS_CERT_PATH = os.getenv(
35
+ "BLOSSOMTUNE_TLS_CERT_PATH",
36
+ "/data/certs/server.crt"
37
+ if os.path.isdir("/data/certs")
38
+ else "./data/certs/server.crt",
39
+ )
blossomtune_gradio/generate_tls.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import logging
4
+
5
+
6
+ from blossomtune_gradio.tls import TLSGenerator
7
+ from blossomtune_gradio import config as cfg
8
+
9
+
10
+ # Configure basic logging for the script
11
+ logging.basicConfig(
12
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
13
+ )
14
+
15
+
16
+ def generate_dev_cert():
17
+ """Generates a self-signed certificate for localhost development."""
18
+ try:
19
+ print("\n--- Generating self-signed certificate for localhost ---")
20
+ cert_dir = "certificates_localhost"
21
+ if os.path.exists(cert_dir):
22
+ shutil.rmtree(cert_dir)
23
+
24
+ generator = TLSGenerator(cert_dir=cert_dir)
25
+ # Note: No existing CA is passed, so a new one will be created.
26
+ generator.generate_server_certificate(
27
+ common_name="localhost", sans=["localhost", "127.0.0.1"]
28
+ )
29
+ print(f"\n✅ Success! Self-signed CA and server cert created in '{cert_dir}'.")
30
+ except Exception as e:
31
+ print(f"\n❌ An error occurred: {e}")
32
+
33
+
34
+ def generate_prod_cert():
35
+ """Generates a server certificate signed by the CA specified in config."""
36
+ if not cfg.TLS_CA_KEY_PATH or not cfg.TLS_CA_CERT_PATH:
37
+ print(
38
+ "\n❌ Error: TLS_CA_KEY_PATH and TLS_CA_CERT_PATH are not set in your config."
39
+ )
40
+ print("Please configure the paths to your main CA certificate and key.")
41
+ return
42
+
43
+ try:
44
+ print(
45
+ f"\n--- Generating production certificate signed by {cfg.TLS_CA_CERT_PATH} ---"
46
+ )
47
+ common_name = input(
48
+ "Enter the primary domain name for the server (e.g., fl.mydomain.com): "
49
+ ).strip()
50
+ if not common_name:
51
+ print("Error: Domain name cannot be empty.")
52
+ return
53
+
54
+ generator = TLSGenerator(cert_dir=cfg.TLS_CERT_DIR)
55
+ generator.generate_server_certificate(
56
+ common_name=common_name,
57
+ ca_key_path=cfg.TLS_CA_KEY_PATH,
58
+ ca_cert_path=cfg.TLS_CA_CERT_PATH,
59
+ )
60
+ print(
61
+ f"\n✅ Success! Server certificate and key created in '{cfg.TLS_CERT_DIR}'."
62
+ )
63
+ except Exception as e:
64
+ print(f"\n❌ An error occurred: {e}")
65
+
66
+
67
+ def main():
68
+ """Main function to run the interactive menu."""
69
+ while True:
70
+ print("\n===== BlossomTune TLS Certificate Generator =====")
71
+ print("Select an option:")
72
+ print(" 1. Generate a self-signed 'localhost' certificate (for Development)")
73
+ print(" 2. Generate a server certificate using the main CA (for Production)")
74
+ print(" 3. Exit")
75
+
76
+ choice = input("Enter your choice [1]: ").strip() or "1"
77
+
78
+ if choice == "1":
79
+ generate_dev_cert()
80
+ elif choice == "2":
81
+ generate_prod_cert()
82
+ elif choice == "3":
83
+ print("Exiting.")
84
+ break
85
+ else:
86
+ print("Invalid choice. Please try again.")
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
blossomtune_gradio/processing.py CHANGED
@@ -37,6 +37,11 @@ def run_process(command, process_key):
37
 
38
 
39
  def start_superlink():
 
 
 
 
 
40
  if process_store["superlink"] and process_store["superlink"].poll() is None:
41
  return False, "Superlink process is already running."
42
  command = [shutil.which("flower-superlink"), "--insecure"]
 
37
 
38
 
39
  def start_superlink():
40
+ # Do not start an internal process if in external mode.
41
+ if cfg.SUPERLINK_MODE == "external":
42
+ log.warning("start_superlink called while in external mode. Operation aborted.")
43
+ return False, "Application is in external Superlink mode."
44
+
45
  if process_store["superlink"] and process_store["superlink"].poll() is None:
46
  return False, "Superlink process is already running."
47
  command = [shutil.which("flower-superlink"), "--insecure"]
blossomtune_gradio/tls.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ import logging
4
+ from ipaddress import ip_address
5
+ from cryptography import x509
6
+ from cryptography.x509.oid import NameOID
7
+ from cryptography.hazmat.primitives import hashes, serialization
8
+ from cryptography.hazmat.primitives.asymmetric import rsa
9
+ from cryptography.hazmat.backends import default_backend
10
+
11
+ log = logging.getLogger(__name__)
12
+
13
+
14
+ class TLSGenerator:
15
+ """
16
+ A class to handle the generation of TLS certificates, keys, and CSRs.
17
+ This class contains the core cryptographic logic and is separated from
18
+ any user interface.
19
+ """
20
+
21
+ def __init__(self, cert_dir: str = "certificates"):
22
+ self.cert_dir = cert_dir
23
+ os.makedirs(self.cert_dir, exist_ok=True)
24
+ log.info(f"Certificate directory set to: {self.cert_dir}")
25
+
26
+ def _generate_private_key(self, filename: str) -> rsa.RSAPrivateKey:
27
+ """Generates and saves a 4096-bit RSA private key."""
28
+ log.info(f"Generating private key: {filename}...")
29
+ private_key = rsa.generate_private_key(
30
+ public_exponent=65537, key_size=4096, backend=default_backend()
31
+ )
32
+ key_path = os.path.join(self.cert_dir, filename)
33
+ with open(key_path, "wb") as f:
34
+ f.write(
35
+ private_key.private_bytes(
36
+ encoding=serialization.Encoding.PEM,
37
+ format=serialization.PrivateFormat.TraditionalOpenSSL,
38
+ encryption_algorithm=serialization.NoEncryption(),
39
+ )
40
+ )
41
+ log.info(f"Private key saved to {key_path}")
42
+ return private_key
43
+
44
+ def create_ca(self) -> tuple[rsa.RSAPrivateKey, x509.Certificate]:
45
+ """Generates a new self-signed CA key and certificate."""
46
+ log.info("Generating a new self-signed CA...")
47
+ ca_private_key = self._generate_private_key("ca.key")
48
+
49
+ subject = issuer = x509.Name(
50
+ [
51
+ x509.NameAttribute(NameOID.COUNTRY_NAME, "DE"),
52
+ x509.NameAttribute(NameOID.ORGANIZATION_NAME, "BlossomTune CA"),
53
+ x509.NameAttribute(NameOID.COMMON_NAME, "BlossomTune Self-Signed CA"),
54
+ ]
55
+ )
56
+ cert_builder = (
57
+ x509.CertificateBuilder()
58
+ .subject_name(subject)
59
+ .issuer_name(issuer)
60
+ .public_key(ca_private_key.public_key())
61
+ .serial_number(x509.random_serial_number())
62
+ .not_valid_before(datetime.datetime.utcnow())
63
+ .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=730))
64
+ .add_extension(
65
+ x509.BasicConstraints(ca=True, path_length=None), critical=True
66
+ )
67
+ )
68
+ ca_cert = cert_builder.sign(ca_private_key, hashes.SHA256(), default_backend())
69
+
70
+ ca_cert_path = os.path.join(self.cert_dir, "ca.crt")
71
+ with open(ca_cert_path, "wb") as f:
72
+ f.write(ca_cert.public_bytes(serialization.Encoding.PEM))
73
+ log.info(f"CA certificate saved to {ca_cert_path}")
74
+ return ca_private_key, ca_cert
75
+
76
+ def generate_server_certificate(
77
+ self,
78
+ common_name: str,
79
+ sans: list[str] | None = None,
80
+ ca_key_path: str | None = None,
81
+ ca_cert_path: str | None = None,
82
+ ):
83
+ """
84
+ Generates a server key, signs a certificate, and creates a combined server.pem file.
85
+ - If a CA is provided, it uses it.
86
+ - Otherwise, it generates a new self-signed CA first.
87
+ """
88
+ server_private_key = self._generate_private_key("server.key")
89
+
90
+ if ca_key_path and ca_cert_path:
91
+ log.info(f"Loading existing CA from {ca_cert_path}")
92
+ with open(ca_key_path, "rb") as f:
93
+ ca_private_key = serialization.load_pem_private_key(
94
+ f.read(), password=None
95
+ )
96
+ with open(ca_cert_path, "rb") as f:
97
+ ca_cert = x509.load_pem_x509_certificate(f.read())
98
+ else:
99
+ ca_private_key, ca_cert = self.create_ca()
100
+
101
+ subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])
102
+ san_list = [x509.DNSName(common_name)]
103
+ if sans:
104
+ for name in set(sans):
105
+ try:
106
+ san_list.append(x509.IPAddress(ip_address(name)))
107
+ except ValueError:
108
+ san_list.append(x509.DNSName(name))
109
+
110
+ builder = (
111
+ x509.CertificateBuilder()
112
+ .subject_name(subject)
113
+ .issuer_name(ca_cert.subject)
114
+ .public_key(server_private_key.public_key())
115
+ .serial_number(x509.random_serial_number())
116
+ .not_valid_before(datetime.datetime.utcnow())
117
+ .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
118
+ .add_extension(x509.SubjectAlternativeName(san_list), critical=False)
119
+ )
120
+ server_cert = builder.sign(ca_private_key, hashes.SHA256(), default_backend())
121
+
122
+ # Save the individual server certificate (.crt)
123
+ server_cert_path = os.path.join(self.cert_dir, "server.crt")
124
+ server_cert_bytes = server_cert.public_bytes(serialization.Encoding.PEM)
125
+ with open(server_cert_path, "wb") as f:
126
+ f.write(server_cert_bytes)
127
+ log.info(f"Server certificate saved to {server_cert_path}")
128
+
129
+ # Create the combined server.pem file for Flower Superlink
130
+ server_key_bytes = server_private_key.private_bytes(
131
+ encoding=serialization.Encoding.PEM,
132
+ format=serialization.PrivateFormat.TraditionalOpenSSL,
133
+ encryption_algorithm=serialization.NoEncryption(),
134
+ )
135
+ server_pem_path = os.path.join(self.cert_dir, "server.pem")
136
+ with open(server_pem_path, "wb") as f:
137
+ f.write(server_cert_bytes)
138
+ f.write(server_key_bytes)
139
+ log.info(f"Server PEM file (cert + key) saved to {server_pem_path}")
blossomtune_gradio/ui/callbacks.py CHANGED
@@ -8,6 +8,7 @@ from blossomtune_gradio.logs import log
8
  from blossomtune_gradio import federation as fed
9
  from blossomtune_gradio import processing
10
  from blossomtune_gradio.settings import settings
 
11
 
12
  from . import components
13
  from . import auth
@@ -24,7 +25,7 @@ def get_full_status_update(
24
  profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None
25
  ):
26
  owner = auth.is_space_owner(profile, oauth_token)
27
- auth_status = "Authenticating..." # Initial state, not in schema
28
  is_on_space = cfg.SPACE_OWNER is not None
29
  hf_handle_val = ""
30
  hf_handle_interactive = not is_on_space
@@ -53,27 +54,42 @@ def get_full_status_update(
53
  "SELECT participant_id, hf_handle, email, partition_id FROM requests WHERE status = 'approved' ORDER BY timestamp DESC"
54
  ).fetchall()
55
 
56
- superlink_is_running = (
57
- processing.process_store["superlink"]
58
- and processing.process_store["superlink"].poll() is None
59
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  runner_is_running = (
61
  processing.process_store["runner"]
62
  and processing.process_store["runner"].poll() is None
63
  )
64
-
65
- # Hardcode status text as it's not in the schema
66
- superlink_status = "🟢 Running" if superlink_is_running else "🔴 Not Running"
67
  runner_status = "🟢 Running" if runner_is_running else "🔴 Not Running"
68
 
69
- # Hardcode button text as it's not in the schema
70
- if superlink_is_running:
71
- superlink_btn_update = gr.update(value="🛑 Stop Superlink", variant="stop")
72
- else:
73
- superlink_btn_update = gr.update(
74
- value="🚀 Start Superlink", variant="secondary"
75
- )
76
-
77
  if runner_is_running:
78
  runner_btn_update = gr.update(value="🛑 Stop Runner", variant="stop")
79
  else:
 
8
  from blossomtune_gradio import federation as fed
9
  from blossomtune_gradio import processing
10
  from blossomtune_gradio.settings import settings
11
+ from blossomtune_gradio import util
12
 
13
  from . import components
14
  from . import auth
 
25
  profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None
26
  ):
27
  owner = auth.is_space_owner(profile, oauth_token)
28
+ auth_status = "Authenticating..."
29
  is_on_space = cfg.SPACE_OWNER is not None
30
  hf_handle_val = ""
31
  hf_handle_interactive = not is_on_space
 
54
  "SELECT participant_id, hf_handle, email, partition_id FROM requests WHERE status = 'approved' ORDER BY timestamp DESC"
55
  ).fetchall()
56
 
57
+ # Superlink Status Logic
58
+ superlink_btn_update = gr.update() # Default empty update
59
+
60
+ if cfg.SUPERLINK_MODE == "internal":
61
+ superlink_is_running = (
62
+ processing.process_store["superlink"]
63
+ and processing.process_store["superlink"].poll() is None
64
+ )
65
+ superlink_status = "🟢 Running" if superlink_is_running else "🔴 Not Running"
66
+ if superlink_is_running:
67
+ superlink_btn_update = gr.update(
68
+ value="🛑 Stop Superlink", variant="stop", interactive=True
69
+ )
70
+ else:
71
+ superlink_btn_update = gr.update(
72
+ value="🚀 Start Superlink", variant="secondary", interactive=True
73
+ )
74
+
75
+ elif cfg.SUPERLINK_MODE == "external":
76
+ if not cfg.SUPERLINK_HOST:
77
+ superlink_status = "🔴 Not Configured"
78
+ else:
79
+ is_open = util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT)
80
+ superlink_status = "🟢 Running" if is_open else "🔴 Not Running"
81
+ # Disable the button in external mode
82
+ superlink_btn_update = gr.update(value="Managed Externally", interactive=False)
83
+ else:
84
+ superlink_status = "⚠️ Invalid Mode"
85
+ superlink_btn_update = gr.update(interactive=False)
86
+
87
  runner_is_running = (
88
  processing.process_store["runner"]
89
  and processing.process_store["runner"].poll() is None
90
  )
 
 
 
91
  runner_status = "🟢 Running" if runner_is_running else "🔴 Not Running"
92
 
 
 
 
 
 
 
 
 
93
  if runner_is_running:
94
  runner_btn_update = gr.update(value="🛑 Stop Runner", variant="stop")
95
  else:
blossomtune_gradio/util.py CHANGED
@@ -1,7 +1,32 @@
1
  import re
 
2
  import dns.resolver
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def validate_email(email_address: str) -> bool:
6
  """
7
  Validates an email address using regex for format and
 
1
  import re
2
+ import socket
3
  import dns.resolver
4
 
5
 
6
+ def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
7
+ """
8
+ Checks if a TCP port is open on a given host.
9
+
10
+ Args:
11
+ host: The hostname or IP address to check.
12
+ port: The port number to check.
13
+ timeout: The connection timeout in seconds.
14
+
15
+ Returns:
16
+ True if the port is open and a connection can be established,
17
+ False otherwise.
18
+ """
19
+ try:
20
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
21
+ s.settimeout(timeout)
22
+ s.connect((host, port))
23
+ print(f"TCP check successful: Port {port} is open on {host}.")
24
+ return True
25
+ except (socket.timeout, ConnectionRefusedError, OSError) as e:
26
+ print(f"TCP check failed: Port {port} on {host} is not open. Error: {e}")
27
+ return False
28
+
29
+
30
  def validate_email(email_address: str) -> bool:
31
  """
32
  Validates an email address using regex for format and
data/cache/.gitkeep ADDED
File without changes
data/certs/.gitkeep ADDED
File without changes
data/db/.gitkeep ADDED
File without changes
data/keys/.gitkeep ADDED
File without changes
docker-compose.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ gradio_app:
3
+ image: ethicalabs/blossomtune-orchestrator:latest
4
+ container_name: gradio_app
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ command: ["gradio_app"]
9
+ ports:
10
+ - "7860:7860" # Expose the Gradio port to the host machine
11
+ volumes:
12
+ - ./data/db:/data/db # Mount the database directory for persistence
13
+ - ./data/certs:/data/certs:ro # Mount TLS certificates (read-only)
14
+ - ./data/keys:/data/keys:rw # Mount authentication keys (read-write)
15
+ - ./data/cache:/root/.cache # Mount Hugging Face cache for models/datasets
16
+ depends_on:
17
+ - superlink # Optional: Ensures superlink starts before the UI
18
+ environment:
19
+ HF_TOKEN: ${HF_TOKEN}
20
+ SUPERLINK_MODE: external
21
+ SUPERLINK_HOST: host.docker.internal
22
+ superlink:
23
+ image: ethicalabs/blossomtune-orchestrator:latest
24
+ container_name: superlink
25
+ build:
26
+ context: .
27
+ dockerfile: Dockerfile
28
+ command: ["superlink"]
29
+ ports:
30
+ - "127.0.0.1:9092:9092" # Port for SuperNode connections
31
+ - "9093:9093" # Port for Flower CLI (e.g., flwr run)
32
+ volumes:
33
+ - ./data/certs:/data/certs:ro # Mount TLS certificates (read-only)
34
+ - ./data/keys:/data/keys:ro # Mount authentication keys (read-only)
35
+ - ./data/results:/data/results # Mount results directory for persisting run artifacts
36
+ mailhog:
37
+ image: mailhog/mailhog
38
+ container_name: mailhog
39
+ restart: always
40
+ volumes:
41
+ - ./mailhog.auth:/mailhog.auth:ro
42
+ - ./data:/data:rw
43
+ ports:
44
+ - "1025:1025"
45
+ - "8025:8025"
docker_entrypoint.sh CHANGED
@@ -8,6 +8,12 @@ env | grep -v -E "SECRET|KEY|TOKEN|PASSWORD|ACCESS"
8
  if [ "${1}" = "gradio_app" ]; then
9
  echo "Running Gradio App..."
10
  exec python3 -m blossomtune_gradio
 
 
 
 
 
 
11
  else
12
  exec "$@"
13
  fi
 
8
  if [ "${1}" = "gradio_app" ]; then
9
  echo "Running Gradio App..."
10
  exec python3 -m blossomtune_gradio
11
+ elif [ "${1}" = "superlink" ]; then
12
+ echo "Running Superlink..."
13
+ exec flower-superlink \
14
+ --ssl-ca-certfile /data/certs/ca.crt \
15
+ --ssl-certfile /data/certs/server.pem \
16
+ --ssl-keyfile /data/certs/server.key
17
  else
18
  exec "$@"
19
  fi
pyproject.toml CHANGED
@@ -68,6 +68,7 @@ convention = "google" # Accepts: "google", "numpy", or "pep257".
68
 
69
  [dependency-groups]
70
  dev = [
 
71
  "dnspython>=2.8.0",
72
  "pytest>=8.4.1",
73
  "pytest-mock>=3.15.1",
 
68
 
69
  [dependency-groups]
70
  dev = [
71
+ "cryptography>=44.0.3",
72
  "dnspython>=2.8.0",
73
  "pytest>=8.4.1",
74
  "pytest-mock>=3.15.1",
tests/test_auth_keys.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import stat
4
+ import pytest
5
+ from cryptography.hazmat.primitives import serialization
6
+ from cryptography.hazmat.primitives.asymmetric import ec
7
+
8
+ from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
9
+
10
+
11
+ @pytest.fixture
12
+ def key_generator(tmp_path):
13
+ """Fixture to create an AuthKeyGenerator instance in a temporary directory."""
14
+ key_dir = tmp_path / "auth_keys"
15
+ return AuthKeyGenerator(key_dir=str(key_dir))
16
+
17
+
18
+ class TestAuthKeyGenerator:
19
+ """Test suite for the AuthKeyGenerator class."""
20
+
21
+ def test_init_creates_directory(self, tmp_path):
22
+ """Verify that the key directory is created on initialization."""
23
+ key_dir = tmp_path / "new_keys"
24
+ assert not os.path.exists(key_dir)
25
+ AuthKeyGenerator(key_dir=str(key_dir))
26
+ assert os.path.exists(key_dir)
27
+
28
+ def test_generate_participant_keys_creates_files_and_returns_data(
29
+ self, key_generator
30
+ ):
31
+ """
32
+ Verify that the main method generates all expected files and returns
33
+ the correct data tuple.
34
+ """
35
+ participant_id = "participant_01"
36
+ priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
37
+ participant_id
38
+ )
39
+
40
+ # 1. Check if files exist
41
+ assert os.path.exists(priv_path)
42
+ assert os.path.exists(pub_path)
43
+ assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
44
+ assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")
45
+
46
+ # 2. Check private key file permissions (security check)
47
+ # In non-Windows environments, check for 600 permissions.
48
+ if os.name != "nt":
49
+ file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
50
+ assert file_mode == 0o600
51
+
52
+ # 3. Verify key formats and consistency
53
+ with open(priv_path, "rb") as f:
54
+ private_key = serialization.load_pem_private_key(f.read(), password=None)
55
+ with open(pub_path, "rb") as f:
56
+ public_key = serialization.load_pem_public_key(f.read())
57
+ _ = f.read() # Read again to get bytes
58
+
59
+ assert isinstance(private_key, ec.EllipticCurvePrivateKey)
60
+ assert isinstance(public_key, ec.EllipticCurvePublicKey)
61
+
62
+ # Check that the returned PEM string matches the public key
63
+ generated_public_key = private_key.public_key()
64
+ pem_from_private = generated_public_key.public_bytes(
65
+ encoding=serialization.Encoding.PEM,
66
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
67
+ ).decode("utf-8")
68
+ assert pub_pem == pem_from_private
69
+
70
+
71
+ class TestRebuildCSV:
72
+ """Test suite for the rebuild_authorized_keys_csv function."""
73
+
74
+ def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
75
+ """Verify a CSV with only a header is created for an empty participant list."""
76
+ key_dir = tmp_path / "csv_test"
77
+ os.makedirs(key_dir)
78
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
79
+
80
+ rebuild_authorized_keys_csv(key_dir, [])
81
+
82
+ assert os.path.exists(csv_path)
83
+ with open(csv_path, "r") as f:
84
+ reader = csv.reader(f)
85
+ header = next(reader)
86
+ assert header == ["participant_id", "public_key_pem"]
87
+ # Check that there are no more rows
88
+ with pytest.raises(StopIteration):
89
+ next(reader)
90
+
91
+ def test_rebuild_csv_writes_correct_data(self, tmp_path):
92
+ """Verify the CSV is created with the correct participant data."""
93
+ key_dir = tmp_path / "csv_test"
94
+ os.makedirs(key_dir)
95
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
96
+
97
+ participants = [
98
+ ("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
99
+ ("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
100
+ ]
101
+ rebuild_authorized_keys_csv(key_dir, participants)
102
+
103
+ with open(csv_path, "r") as f:
104
+ reader = csv.reader(f)
105
+ header = next(reader)
106
+ row1 = next(reader)
107
+ row2 = next(reader)
108
+ assert header == ["participant_id", "public_key_pem"]
109
+ assert row1 == list(participants[0])
110
+ assert row2 == list(participants[1])
111
+
112
+ def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
113
+ """Verify that an existing CSV file is correctly overwritten."""
114
+ key_dir = tmp_path / "csv_test"
115
+ os.makedirs(key_dir)
116
+ csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
117
+
118
+ # First run with initial data
119
+ initial_participants = [("old_p1", "old_key_1")]
120
+ rebuild_authorized_keys_csv(key_dir, initial_participants)
121
+
122
+ # Second run with new data
123
+ new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
124
+ rebuild_authorized_keys_csv(key_dir, new_participants)
125
+
126
+ with open(csv_path, "r") as f:
127
+ reader = csv.reader(f)
128
+ _ = next(reader)
129
+ rows = list(reader)
130
+
131
+ assert len(rows) == 2
132
+ assert rows[0] == list(new_participants[0])
133
+ assert rows[1] == list(new_participants[1])
tests/test_blossomfile.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import zipfile
4
+ import pytest
5
+
6
+ from blossomtune_gradio.blossomfile import create_blossomfile
7
+
8
+
9
+ @pytest.fixture
10
+ def dummy_credential_files(tmp_path):
11
+ """
12
+ Pytest fixture to create dummy credential files in a temporary directory.
13
+ Returns the paths to the created files.
14
+ """
15
+ creds_dir = tmp_path / "creds"
16
+ os.makedirs(creds_dir)
17
+
18
+ ca_cert_path = creds_dir / "ca.crt"
19
+ auth_key_path = creds_dir / "auth.key"
20
+ auth_pub_path = creds_dir / "auth.pub"
21
+
22
+ ca_cert_path.write_text("---BEGIN CERTIFICATE---")
23
+ auth_key_path.write_text("---BEGIN EC PRIVATE KEY---")
24
+ auth_pub_path.write_text("---BEGIN PUBLIC KEY---")
25
+
26
+ return {
27
+ "ca_cert_path": str(ca_cert_path),
28
+ "auth_key_path": str(auth_key_path),
29
+ "auth_pub_path": str(auth_pub_path),
30
+ }
31
+
32
+
33
+ def test_create_blossomfile_success(tmp_path, dummy_credential_files):
34
+ """
35
+ Tests the successful creation of a .blossomfile, verifying its contents.
36
+ """
37
+ output_dir = tmp_path / "output"
38
+ participant_id = "participant_abc"
39
+
40
+ blossomfile_path = create_blossomfile(
41
+ participant_id=participant_id,
42
+ output_dir=str(output_dir),
43
+ ca_cert_path=dummy_credential_files["ca_cert_path"],
44
+ auth_key_path=dummy_credential_files["auth_key_path"],
45
+ auth_pub_path=dummy_credential_files["auth_pub_path"],
46
+ superlink_address="blossomtune-test.ethicalabs.ai:9092",
47
+ partition_id=5,
48
+ num_partitions=10,
49
+ )
50
+
51
+ # 1. Verify the file was created at the correct path
52
+ assert os.path.exists(blossomfile_path)
53
+ assert blossomfile_path == str(output_dir / f"{participant_id}.blossomfile")
54
+
55
+ # 2. Verify the contents of the zip archive
56
+ with zipfile.ZipFile(blossomfile_path, "r") as zf:
57
+ # Check that all expected files are present
58
+ namelist = zf.namelist()
59
+ assert "blossom.json" in namelist
60
+ assert "ca.crt" in namelist
61
+ assert "auth.key" in namelist
62
+ assert "auth.pub" in namelist
63
+ assert len(namelist) == 4
64
+
65
+ # Check the content of blossom.json
66
+ with zf.open("blossom.json") as f:
67
+ config_data = json.load(f)
68
+ assert (
69
+ config_data["superlink_address"]
70
+ == "blossomtune-test.ethicalabs.ai:9092"
71
+ )
72
+ assert config_data["node_config"]["partition-id"] == 5
73
+ assert config_data["node_config"]["num-partitions"] == 10
74
+
75
+ # Check the content of the credential files
76
+ assert zf.read("ca.crt").decode("utf-8") == "---BEGIN CERTIFICATE---"
77
+ assert zf.read("auth.key").decode("utf-8") == "---BEGIN EC PRIVATE KEY---"
78
+ assert zf.read("auth.pub").decode("utf-8") == "---BEGIN PUBLIC KEY---"
79
+
80
+
81
+ def test_create_blossomfile_missing_input_file(tmp_path):
82
+ """
83
+ Tests that the function raises FileNotFoundError if a required credential
84
+ file is missing and cleans up the partial archive.
85
+ """
86
+ output_dir = tmp_path / "output"
87
+ participant_id = "participant_xyz"
88
+ missing_file_path = tmp_path / "creds" / "non_existent.key"
89
+ blossomfile_path = output_dir / f"{participant_id}.blossomfile"
90
+
91
+ with pytest.raises(FileNotFoundError, match="Required credential file not found"):
92
+ create_blossomfile(
93
+ participant_id=participant_id,
94
+ output_dir=str(output_dir),
95
+ ca_cert_path=str(missing_file_path), # Pass a path that doesn't exist
96
+ auth_key_path="dummy",
97
+ auth_pub_path="dummy",
98
+ superlink_address="test:9092",
99
+ partition_id=1,
100
+ num_partitions=2,
101
+ )
102
+
103
+ # Verify that the partially created blossomfile was removed
104
+ assert not os.path.exists(blossomfile_path)
tests/test_tls.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+ from cryptography import x509
4
+
5
+ from blossomtune_gradio.tls import TLSGenerator
6
+
7
+
8
+ @pytest.fixture
9
+ def tls_generator(tmp_path):
10
+ """Fixture to create a TLSGenerator instance in a temporary directory."""
11
+ cert_dir = tmp_path / "certs"
12
+ return TLSGenerator(cert_dir=str(cert_dir))
13
+
14
+
15
+ class TestTLSGenerator:
16
+ """Test suite for the TLSGenerator class."""
17
+
18
+ def test_init_creates_directory(self, tmp_path):
19
+ """Verify that the certificate directory is created on initialization."""
20
+ cert_dir = tmp_path / "new_certs"
21
+ assert not os.path.exists(cert_dir)
22
+ TLSGenerator(cert_dir=str(cert_dir))
23
+ assert os.path.exists(cert_dir)
24
+
25
+ def test_create_ca(self, tls_generator):
26
+ """Test the creation of a self-signed Certificate Authority."""
27
+ ca_key, ca_cert = tls_generator.create_ca()
28
+
29
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.key"))
30
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.crt"))
31
+ assert ca_cert.issuer == ca_cert.subject
32
+ assert (
33
+ ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value.ca
34
+ is True
35
+ )
36
+
37
+ def test_generate_server_certificate_with_new_ca(self, tls_generator):
38
+ """
39
+ Test generating a server certificate, which should also create a new CA
40
+ and the combined server.pem file.
41
+ """
42
+ common_name = "test.local"
43
+ sans = ["test.local", "192.168.1.10"]
44
+ tls_generator.generate_server_certificate(common_name=common_name, sans=sans)
45
+
46
+ # Check for all expected files
47
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.key"))
48
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.crt"))
49
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "server.key"))
50
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "server.crt"))
51
+ assert os.path.exists(os.path.join(tls_generator.cert_dir, "server.pem"))
52
+
53
+ # Load certs to verify issuer relationship
54
+ with open(os.path.join(tls_generator.cert_dir, "ca.crt"), "rb") as f:
55
+ ca_cert = x509.load_pem_x509_certificate(f.read())
56
+ with open(os.path.join(tls_generator.cert_dir, "server.crt"), "rb") as f:
57
+ server_cert = x509.load_pem_x509_certificate(f.read())
58
+
59
+ assert server_cert.issuer == ca_cert.subject
60
+
61
+ # Verify PEM content
62
+ with open(os.path.join(tls_generator.cert_dir, "server.pem"), "r") as f:
63
+ pem_content = f.read()
64
+ assert "-----BEGIN CERTIFICATE-----" in pem_content
65
+ assert "-----BEGIN RSA PRIVATE KEY-----" in pem_content
66
+
67
+ def test_generate_server_certificate_with_existing_ca(self, tls_generator):
68
+ """Test generating a server certificate using a pre-existing CA."""
69
+ tls_generator.create_ca()
70
+ ca_key_path = os.path.join(tls_generator.cert_dir, "ca.key")
71
+ ca_cert_path = os.path.join(tls_generator.cert_dir, "ca.crt")
72
+
73
+ server_gen_dir = os.path.join(
74
+ os.path.dirname(tls_generator.cert_dir), "server_certs"
75
+ )
76
+ server_generator = TLSGenerator(cert_dir=server_gen_dir)
77
+
78
+ server_generator.generate_server_certificate(
79
+ common_name="prod.local", ca_key_path=ca_key_path, ca_cert_path=ca_cert_path
80
+ )
81
+
82
+ # Check that server files were created in the new directory
83
+ assert os.path.exists(os.path.join(server_gen_dir, "server.key"))
84
+ assert os.path.exists(os.path.join(server_gen_dir, "server.crt"))
85
+ assert os.path.exists(os.path.join(server_gen_dir, "server.pem"))
86
+ # Check that a *new* CA was NOT created in the server directory
87
+ assert not os.path.exists(os.path.join(server_gen_dir, "ca.key"))
88
+
89
+ # Verify PEM content
90
+ with open(os.path.join(server_gen_dir, "server.pem"), "r") as f:
91
+ pem_content = f.read()
92
+ assert "-----BEGIN CERTIFICATE-----" in pem_content
93
+ assert "-----BEGIN RSA PRIVATE KEY-----" in pem_content
tests/test_util.py CHANGED
@@ -1,8 +1,43 @@
1
  import pytest
 
2
  import dns.resolver
3
  from unittest.mock import MagicMock
4
 
5
- from blossomtune_gradio.util import validate_email, strtobool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  def test_validate_email_valid(monkeypatch):
 
1
  import pytest
2
+ import socket
3
  import dns.resolver
4
  from unittest.mock import MagicMock
5
 
6
+ from blossomtune_gradio.util import is_port_open, validate_email, strtobool
7
+
8
+
9
+ def test_is_port_open_success(mocker):
10
+ """
11
+ Tests the case where the port is open and the connection succeeds.
12
+ """
13
+ mock_socket = mocker.patch("blossomtune_gradio.util.socket.socket")
14
+ mock_socket.return_value.__enter__.return_value.connect.return_value = None
15
+
16
+ result = is_port_open("testhost", 1234)
17
+
18
+ assert result is True
19
+ # Verify that the socket was created and connect was called with the correct args
20
+ mock_socket.return_value.__enter__.return_value.settimeout.assert_called_once_with(
21
+ 1.0
22
+ )
23
+ mock_socket.return_value.__enter__.return_value.connect.assert_called_once_with(
24
+ ("testhost", 1234)
25
+ )
26
+
27
+
28
+ @pytest.mark.parametrize("exception", [ConnectionRefusedError, socket.timeout, OSError])
29
+ def test_is_port_open_failures(mocker, exception):
30
+ """
31
+ Tests various failure scenarios where the port is not open.
32
+ - ConnectionRefusedError: The host actively refuses the connection.
33
+ - socket.timeout: The connection attempt times out.
34
+ - OSError: A generic network error occurs.
35
+ """
36
+ mock_socket = mocker.patch("blossomtune_gradio.util.socket.socket")
37
+ mock_socket.return_value.__enter__.return_value.connect.side_effect = exception
38
+
39
+ result = is_port_open("testhost", 1234)
40
+ assert result is False
41
 
42
 
43
  def test_validate_email_valid(monkeypatch):
uv.lock CHANGED
@@ -241,6 +241,7 @@ dependencies = [
241
 
242
  [package.dev-dependencies]
243
  dev = [
 
244
  { name = "dnspython" },
245
  { name = "pytest" },
246
  { name = "pytest-mock" },
@@ -262,6 +263,7 @@ requires-dist = [
262
 
263
  [package.metadata.requires-dev]
264
  dev = [
 
265
  { name = "dnspython", specifier = ">=2.8.0" },
266
  { name = "pytest", specifier = ">=8.4.1" },
267
  { name = "pytest-mock", specifier = ">=3.15.1" },
 
241
 
242
  [package.dev-dependencies]
243
  dev = [
244
+ { name = "cryptography" },
245
  { name = "dnspython" },
246
  { name = "pytest" },
247
  { name = "pytest-mock" },
 
263
 
264
  [package.metadata.requires-dev]
265
  dev = [
266
+ { name = "cryptography", specifier = ">=44.0.3" },
267
  { name = "dnspython", specifier = ">=2.8.0" },
268
  { name = "pytest", specifier = ">=8.4.1" },
269
  { name = "pytest-mock", specifier = ">=3.15.1" },