Spaces:
Sleeping
Sleeping
TLS certs and external superlink support (plus initial work on authentication keys and .blossomfile) (#3)
Browse files* Adds script for TLS certs generation.
* Adds support for SUPERLINK_MODE (internal or external)
* Updates UI to support external superlink.
* Adds docker-compose.yaml file for self-signed local certs testing.
* Adds more tests
* Updates README.md
* Adds .blossomfile spec: a single file for the participant, with the configuration needed to connect to a federation. Work in progress!
- README.md +44 -0
- blossomtune_gradio/auth_keys.py +120 -0
- blossomtune_gradio/blossomfile.py +87 -0
- blossomtune_gradio/config.py +18 -2
- blossomtune_gradio/generate_tls.py +90 -0
- blossomtune_gradio/processing.py +5 -0
- blossomtune_gradio/tls.py +139 -0
- blossomtune_gradio/ui/callbacks.py +32 -16
- blossomtune_gradio/util.py +25 -0
- data/cache/.gitkeep +0 -0
- data/certs/.gitkeep +0 -0
- data/db/.gitkeep +0 -0
- data/keys/.gitkeep +0 -0
- docker-compose.yaml +45 -0
- docker_entrypoint.sh +6 -0
- pyproject.toml +1 -0
- tests/test_auth_keys.py +133 -0
- tests/test_blossomfile.py +104 -0
- tests/test_tls.py +93 -0
- tests/test_util.py +36 -1
- uv.lock +2 -0
README.md
CHANGED
|
@@ -77,6 +77,50 @@ python -m blossomtune_gradio
|
|
| 77 |
|
| 78 |
The application will be accessible via a local URL provided by Gradio.
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
## Usage Guide
|
| 81 |
|
| 82 |
### For Participants
|
|
|
|
| 77 |
|
| 78 |
The application will be accessible via a local URL provided by Gradio.
|
| 79 |
|
| 80 |
+
## Generating Self-Signed Certificates for Local Development (Docker)
|
| 81 |
+
|
| 82 |
+
When running the application with `docker-compose`, the `superlink` service requires TLS certificates to enable secure connections.
|
| 83 |
+
|
| 84 |
+
For local development, you can generate a self-signed Certificate Authority (CA) and a `localhost` certificate using the provided script.
|
| 85 |
+
|
| 86 |
+
**Step 1: Run the Certificate Generator**
|
| 87 |
+
|
| 88 |
+
Execute the interactive TLS generation script located in the `blossomtune_gradio` directory:
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
python3 -m blossomtune_gradio.generate_tls
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
**Step 2: Choose the Development Option**
|
| 95 |
+
|
| 96 |
+
When prompted, select option **1** to generate a self-signed certificate for `localhost`.
|
| 97 |
+
|
| 98 |
+
```text
|
| 99 |
+
===== BlossomTune TLS Certificate Generator =====
|
| 100 |
+
Select an option:
|
| 101 |
+
1. Generate a self-signed 'localhost' certificate (for Development)
|
| 102 |
+
2. Generate a server certificate using the main CA (for Production)
|
| 103 |
+
3. Exit
|
| 104 |
+
Enter your choice [1]: 1
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
The script will create a new directory named `certificates_localhost` containing the generated CA (`ca.crt`) and the server certificate files (`server.key`, `server.crt`, `server.pem`).
|
| 108 |
+
|
| 109 |
+
**Step 3: Copy Certificates to the Data Directory**
|
| 110 |
+
|
| 111 |
+
The `docker-compose.yml` file is configured to mount a local `./data/certs` directory into the `superlink` container. You must copy the essential certificate files into this location:
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
cp certificates_localhost/ca.crt ./data/certs/
|
| 115 |
+
cp certificates_localhost/server.key ./data/certs/
|
| 116 |
+
cp certificates_localhost/server.pem ./data/certs/
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Once these files are in place, you can start the services using `docker compose up`.
|
| 120 |
+
|
| 121 |
+
The Superlink will automatically find and use these certificates to secure its connections.
|
| 122 |
+
|
| 123 |
+
|
| 124 |
## Usage Guide
|
| 125 |
|
| 126 |
### For Participants
|
blossomtune_gradio/auth_keys.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import logging
|
| 4 |
+
from typing import List, Tuple
|
| 5 |
+
from cryptography.hazmat.primitives.asymmetric import ec
|
| 6 |
+
from cryptography.hazmat.primitives import serialization
|
| 7 |
+
|
| 8 |
+
log = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def rebuild_authorized_keys_csv(
|
| 12 |
+
key_dir: str, authorized_participants: List[Tuple[str, str]]
|
| 13 |
+
):
|
| 14 |
+
"""
|
| 15 |
+
Overwrites the public key CSV with a fresh list from a trusted source.
|
| 16 |
+
|
| 17 |
+
This function should be called before starting the Flower SuperLink to ensure
|
| 18 |
+
the list of authorized nodes is always in sync with the database.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
key_dir (str): The directory where the CSV will be stored.
|
| 22 |
+
authorized_participants (List[Tuple[str, str]]): A list of tuples,
|
| 23 |
+
where each tuple contains (participant_id, public_key_pem).
|
| 24 |
+
"""
|
| 25 |
+
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
|
| 26 |
+
log.info(f"Rebuilding authorized keys CSV at: {csv_path}")
|
| 27 |
+
|
| 28 |
+
with open(csv_path, "w", newline="") as f:
|
| 29 |
+
writer = csv.writer(f)
|
| 30 |
+
writer.writerow(["participant_id", "public_key_pem"])
|
| 31 |
+
for participant_id, public_key_pem in authorized_participants:
|
| 32 |
+
writer.writerow([participant_id, public_key_pem])
|
| 33 |
+
|
| 34 |
+
log.info(
|
| 35 |
+
f"Successfully rebuilt {csv_path} with {len(authorized_participants)} keys."
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class AuthKeyGenerator:
|
| 40 |
+
"""
|
| 41 |
+
Handles the generation of Elliptic Curve (EC) key pairs for Flower
|
| 42 |
+
SuperNode authentication.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, key_dir: str = "keys"):
|
| 46 |
+
"""
|
| 47 |
+
Initializes the generator and ensures the key directory exists.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
key_dir (str): The directory where key files will be stored.
|
| 51 |
+
"""
|
| 52 |
+
self.key_dir = key_dir
|
| 53 |
+
os.makedirs(self.key_dir, exist_ok=True)
|
| 54 |
+
log.info(f"Authentication key directory set to: {self.key_dir}")
|
| 55 |
+
|
| 56 |
+
def _generate_key_pair(self) -> ec.EllipticCurvePrivateKey:
|
| 57 |
+
"""Generates a single EC private key using the SECP384R1 curve."""
|
| 58 |
+
return ec.generate_private_key(ec.SECP384R1())
|
| 59 |
+
|
| 60 |
+
def _save_private_key(
|
| 61 |
+
self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
|
| 62 |
+
) -> str:
|
| 63 |
+
"""Saves the private key to a file with secure permissions."""
|
| 64 |
+
priv_key_path = os.path.join(self.key_dir, f"{participant_id}.key")
|
| 65 |
+
with open(priv_key_path, "wb") as f:
|
| 66 |
+
f.write(
|
| 67 |
+
private_key.private_bytes(
|
| 68 |
+
encoding=serialization.Encoding.PEM,
|
| 69 |
+
format=serialization.PrivateFormat.PKCS8,
|
| 70 |
+
encryption_algorithm=serialization.NoEncryption(),
|
| 71 |
+
)
|
| 72 |
+
)
|
| 73 |
+
# Set file permissions to read/write for owner only (600)
|
| 74 |
+
os.chmod(priv_key_path, 0o600)
|
| 75 |
+
log.info(f"Private key for {participant_id} saved securely to {priv_key_path}")
|
| 76 |
+
return priv_key_path
|
| 77 |
+
|
| 78 |
+
def _save_public_key_file(
|
| 79 |
+
self, private_key: ec.EllipticCurvePrivateKey, participant_id: str
|
| 80 |
+
) -> str:
|
| 81 |
+
"""Saves the public key to a .pub file."""
|
| 82 |
+
public_key = private_key.public_key()
|
| 83 |
+
pub_key_path = os.path.join(self.key_dir, f"{participant_id}.pub")
|
| 84 |
+
with open(pub_key_path, "wb") as f:
|
| 85 |
+
f.write(
|
| 86 |
+
public_key.public_bytes(
|
| 87 |
+
encoding=serialization.Encoding.PEM,
|
| 88 |
+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
| 89 |
+
)
|
| 90 |
+
)
|
| 91 |
+
log.info(f"Public key for {participant_id} saved to {pub_key_path}")
|
| 92 |
+
return pub_key_path
|
| 93 |
+
|
| 94 |
+
def generate_participant_keys(self, participant_id: str) -> Tuple[str, str, str]:
|
| 95 |
+
"""
|
| 96 |
+
Generates and saves a new EC key pair for a participant.
|
| 97 |
+
|
| 98 |
+
This is the main public method to call when a new participant is approved.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
participant_id (str): A unique identifier for the participant.
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
A tuple containing:
|
| 105 |
+
- The file path to the generated private key.
|
| 106 |
+
- The file path to the generated public key.
|
| 107 |
+
- The public key as a PEM-encoded string (for database storage).
|
| 108 |
+
"""
|
| 109 |
+
private_key = self._generate_key_pair()
|
| 110 |
+
public_key = private_key.public_key()
|
| 111 |
+
|
| 112 |
+
priv_key_path = self._save_private_key(private_key, participant_id)
|
| 113 |
+
pub_key_path = self._save_public_key_file(private_key, participant_id)
|
| 114 |
+
|
| 115 |
+
public_key_pem = public_key.public_bytes(
|
| 116 |
+
encoding=serialization.Encoding.PEM,
|
| 117 |
+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
| 118 |
+
).decode("utf-8")
|
| 119 |
+
|
| 120 |
+
return (priv_key_path, pub_key_path, public_key_pem)
|
blossomtune_gradio/blossomfile.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import zipfile
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
|
| 7 |
+
# Configure logging for the module
|
| 8 |
+
log = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_blossomfile(
|
| 12 |
+
participant_id: str,
|
| 13 |
+
output_dir: str,
|
| 14 |
+
ca_cert_path: str,
|
| 15 |
+
auth_key_path: str,
|
| 16 |
+
auth_pub_path: str,
|
| 17 |
+
superlink_address: str,
|
| 18 |
+
partition_id: int,
|
| 19 |
+
num_partitions: int,
|
| 20 |
+
) -> str:
|
| 21 |
+
"""
|
| 22 |
+
Generates a .blossomfile for a participant.
|
| 23 |
+
|
| 24 |
+
This function packages all necessary credentials and configuration into a
|
| 25 |
+
single, portable .zip archive that a participant can use to easily
|
| 26 |
+
connect to the federation.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
participant_id: The unique identifier for the participant.
|
| 30 |
+
output_dir: The directory where the final .blossomfile will be saved.
|
| 31 |
+
ca_cert_path: Path to the CA public certificate (`ca.crt`).
|
| 32 |
+
auth_key_path: Path to the participant's private EC key (`auth.key`).
|
| 33 |
+
auth_pub_path: Path to the participant's public EC key (`auth.pub`).
|
| 34 |
+
superlink_address: The public address of the Flower SuperLink.
|
| 35 |
+
partition_id: The data partition ID assigned to the participant.
|
| 36 |
+
num_partitions: The total number of partitions in the federation.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
The full path to the generated .blossomfile.
|
| 40 |
+
"""
|
| 41 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 42 |
+
blossomfile_path = os.path.join(output_dir, f"{participant_id}.blossomfile")
|
| 43 |
+
log.info(f"Creating Blossomfile for {participant_id} at {blossomfile_path}")
|
| 44 |
+
|
| 45 |
+
# 1. Create the blossom.json configuration data
|
| 46 |
+
blossom_config: Dict[str, Any] = {
|
| 47 |
+
"superlink_address": superlink_address,
|
| 48 |
+
"node_config": {
|
| 49 |
+
"partition-id": partition_id,
|
| 50 |
+
"num-partitions": num_partitions,
|
| 51 |
+
},
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
# 2. Define the files to be included in the archive
|
| 55 |
+
files_to_add = {
|
| 56 |
+
ca_cert_path: "ca.crt",
|
| 57 |
+
auth_key_path: "auth.key",
|
| 58 |
+
auth_pub_path: "auth.pub",
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
# 3. Create the zip archive
|
| 62 |
+
try:
|
| 63 |
+
with zipfile.ZipFile(blossomfile_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 64 |
+
# Add the configuration file
|
| 65 |
+
zf.writestr("blossom.json", json.dumps(blossom_config, indent=2))
|
| 66 |
+
log.info("Added blossom.json to archive.")
|
| 67 |
+
|
| 68 |
+
# Add the certificate and key files
|
| 69 |
+
for src_path, arc_name in files_to_add.items():
|
| 70 |
+
if os.path.exists(src_path):
|
| 71 |
+
zf.write(src_path, arcname=arc_name)
|
| 72 |
+
log.info(f"Added {arc_name} to archive from {src_path}.")
|
| 73 |
+
else:
|
| 74 |
+
log.error(f"Credential file not found: {src_path}. Aborting.")
|
| 75 |
+
raise FileNotFoundError(
|
| 76 |
+
f"Required credential file not found: {src_path}"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
except Exception as e:
|
| 80 |
+
log.critical(f"Failed to create Blossomfile: {e}")
|
| 81 |
+
# Clean up partially created file on failure
|
| 82 |
+
if os.path.exists(blossomfile_path):
|
| 83 |
+
os.remove(blossomfile_path)
|
| 84 |
+
raise
|
| 85 |
+
|
| 86 |
+
log.info(f"Successfully created Blossomfile: {blossomfile_path}")
|
| 87 |
+
return blossomfile_path
|
blossomtune_gradio/config.py
CHANGED
|
@@ -10,7 +10,9 @@ SPACE_ID = os.getenv("SPACE_ID", "ethicalabs/BlossomTune-Orchestrator")
|
|
| 10 |
SPACE_OWNER = os.getenv("SPACE_OWNER", SPACE_ID.split("/")[0] if SPACE_ID else None)
|
| 11 |
|
| 12 |
# Use persistent storage if available
|
| 13 |
-
DB_PATH =
|
|
|
|
|
|
|
| 14 |
MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
|
| 15 |
SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
|
| 16 |
SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
|
|
@@ -19,5 +21,19 @@ SMTP_REQUIRE_TLS = util.strtobool(os.getenv("SMTP_REQUIRE_TLS", "false"))
|
|
| 19 |
SMTP_USER = os.getenv("SMTP_USER", "")
|
| 20 |
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
|
| 21 |
EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
|
| 22 |
-
SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "")
|
| 23 |
SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
SPACE_OWNER = os.getenv("SPACE_OWNER", SPACE_ID.split("/")[0] if SPACE_ID else None)
|
| 11 |
|
| 12 |
# Use persistent storage if available
|
| 13 |
+
DB_PATH = (
|
| 14 |
+
"/data/db/federation.db" if os.path.isdir("/data/db") else "./data/db/federation.db"
|
| 15 |
+
)
|
| 16 |
MAX_NUM_NODES = int(os.getenv("MAX_NUM_NODES", "20"))
|
| 17 |
SMTP_SENDER = os.getenv("SMTP_SENDER", "hello@ethicalabs.ai")
|
| 18 |
SMTP_SERVER = os.getenv("SMTP_SERVER", "localhost")
|
|
|
|
| 21 |
SMTP_USER = os.getenv("SMTP_USER", "")
|
| 22 |
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
|
| 23 |
EMAIL_PROVIDER = os.getenv("EMAIL_PROVIDER", "smtp")
|
| 24 |
+
SUPERLINK_HOST = os.getenv("SUPERLINK_HOST", "127.0.0.1:9092")
|
| 25 |
SUPERLINK_PORT = int(os.getenv("SUPERLINK_PORT", 9092))
|
| 26 |
+
SUPERLINK_MODE = os.getenv("SUPERLINK_MODE", "internal").lower() # Or external
|
| 27 |
+
|
| 28 |
+
# TLS root cert path. For production only.
|
| 29 |
+
TLS_CERT_DIR = os.getenv("TLS_CERT_DIR", "./certs/")
|
| 30 |
+
TLS_CA_KEY_PATH = os.getenv("TLS_CA_KEY_PATH", False)
|
| 31 |
+
TLS_CA_CERT_PATH = os.getenv("TLS_CA_CERT_PATH", False)
|
| 32 |
+
|
| 33 |
+
# BlossomTune cert - To be distributed to the participants (supernodes).
|
| 34 |
+
BLOSSOMTUNE_TLS_CERT_PATH = os.getenv(
|
| 35 |
+
"BLOSSOMTUNE_TLS_CERT_PATH",
|
| 36 |
+
"/data/certs/server.crt"
|
| 37 |
+
if os.path.isdir("/data/certs")
|
| 38 |
+
else "./data/certs/server.crt",
|
| 39 |
+
)
|
blossomtune_gradio/generate_tls.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from blossomtune_gradio.tls import TLSGenerator
|
| 7 |
+
from blossomtune_gradio import config as cfg
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Configure basic logging for the script
|
| 11 |
+
logging.basicConfig(
|
| 12 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def generate_dev_cert():
|
| 17 |
+
"""Generates a self-signed certificate for localhost development."""
|
| 18 |
+
try:
|
| 19 |
+
print("\n--- Generating self-signed certificate for localhost ---")
|
| 20 |
+
cert_dir = "certificates_localhost"
|
| 21 |
+
if os.path.exists(cert_dir):
|
| 22 |
+
shutil.rmtree(cert_dir)
|
| 23 |
+
|
| 24 |
+
generator = TLSGenerator(cert_dir=cert_dir)
|
| 25 |
+
# Note: No existing CA is passed, so a new one will be created.
|
| 26 |
+
generator.generate_server_certificate(
|
| 27 |
+
common_name="localhost", sans=["localhost", "127.0.0.1"]
|
| 28 |
+
)
|
| 29 |
+
print(f"\n✅ Success! Self-signed CA and server cert created in '{cert_dir}'.")
|
| 30 |
+
except Exception as e:
|
| 31 |
+
print(f"\n❌ An error occurred: {e}")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def generate_prod_cert():
|
| 35 |
+
"""Generates a server certificate signed by the CA specified in config."""
|
| 36 |
+
if not cfg.TLS_CA_KEY_PATH or not cfg.TLS_CA_CERT_PATH:
|
| 37 |
+
print(
|
| 38 |
+
"\n❌ Error: TLS_CA_KEY_PATH and TLS_CA_CERT_PATH are not set in your config."
|
| 39 |
+
)
|
| 40 |
+
print("Please configure the paths to your main CA certificate and key.")
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
print(
|
| 45 |
+
f"\n--- Generating production certificate signed by {cfg.TLS_CA_CERT_PATH} ---"
|
| 46 |
+
)
|
| 47 |
+
common_name = input(
|
| 48 |
+
"Enter the primary domain name for the server (e.g., fl.mydomain.com): "
|
| 49 |
+
).strip()
|
| 50 |
+
if not common_name:
|
| 51 |
+
print("Error: Domain name cannot be empty.")
|
| 52 |
+
return
|
| 53 |
+
|
| 54 |
+
generator = TLSGenerator(cert_dir=cfg.TLS_CERT_DIR)
|
| 55 |
+
generator.generate_server_certificate(
|
| 56 |
+
common_name=common_name,
|
| 57 |
+
ca_key_path=cfg.TLS_CA_KEY_PATH,
|
| 58 |
+
ca_cert_path=cfg.TLS_CA_CERT_PATH,
|
| 59 |
+
)
|
| 60 |
+
print(
|
| 61 |
+
f"\n✅ Success! Server certificate and key created in '{cfg.TLS_CERT_DIR}'."
|
| 62 |
+
)
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"\n❌ An error occurred: {e}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
|
| 68 |
+
"""Main function to run the interactive menu."""
|
| 69 |
+
while True:
|
| 70 |
+
print("\n===== BlossomTune TLS Certificate Generator =====")
|
| 71 |
+
print("Select an option:")
|
| 72 |
+
print(" 1. Generate a self-signed 'localhost' certificate (for Development)")
|
| 73 |
+
print(" 2. Generate a server certificate using the main CA (for Production)")
|
| 74 |
+
print(" 3. Exit")
|
| 75 |
+
|
| 76 |
+
choice = input("Enter your choice [1]: ").strip() or "1"
|
| 77 |
+
|
| 78 |
+
if choice == "1":
|
| 79 |
+
generate_dev_cert()
|
| 80 |
+
elif choice == "2":
|
| 81 |
+
generate_prod_cert()
|
| 82 |
+
elif choice == "3":
|
| 83 |
+
print("Exiting.")
|
| 84 |
+
break
|
| 85 |
+
else:
|
| 86 |
+
print("Invalid choice. Please try again.")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
main()
|
blossomtune_gradio/processing.py
CHANGED
|
@@ -37,6 +37,11 @@ def run_process(command, process_key):
|
|
| 37 |
|
| 38 |
|
| 39 |
def start_superlink():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
if process_store["superlink"] and process_store["superlink"].poll() is None:
|
| 41 |
return False, "Superlink process is already running."
|
| 42 |
command = [shutil.which("flower-superlink"), "--insecure"]
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def start_superlink():
|
| 40 |
+
# Do not start an internal process if in external mode.
|
| 41 |
+
if cfg.SUPERLINK_MODE == "external":
|
| 42 |
+
log.warning("start_superlink called while in external mode. Operation aborted.")
|
| 43 |
+
return False, "Application is in external Superlink mode."
|
| 44 |
+
|
| 45 |
if process_store["superlink"] and process_store["superlink"].poll() is None:
|
| 46 |
return False, "Superlink process is already running."
|
| 47 |
command = [shutil.which("flower-superlink"), "--insecure"]
|
blossomtune_gradio/tls.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import datetime
|
| 3 |
+
import logging
|
| 4 |
+
from ipaddress import ip_address
|
| 5 |
+
from cryptography import x509
|
| 6 |
+
from cryptography.x509.oid import NameOID
|
| 7 |
+
from cryptography.hazmat.primitives import hashes, serialization
|
| 8 |
+
from cryptography.hazmat.primitives.asymmetric import rsa
|
| 9 |
+
from cryptography.hazmat.backends import default_backend
|
| 10 |
+
|
| 11 |
+
log = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TLSGenerator:
|
| 15 |
+
"""
|
| 16 |
+
A class to handle the generation of TLS certificates, keys, and CSRs.
|
| 17 |
+
This class contains the core cryptographic logic and is separated from
|
| 18 |
+
any user interface.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, cert_dir: str = "certificates"):
|
| 22 |
+
self.cert_dir = cert_dir
|
| 23 |
+
os.makedirs(self.cert_dir, exist_ok=True)
|
| 24 |
+
log.info(f"Certificate directory set to: {self.cert_dir}")
|
| 25 |
+
|
| 26 |
+
def _generate_private_key(self, filename: str) -> rsa.RSAPrivateKey:
|
| 27 |
+
"""Generates and saves a 4096-bit RSA private key."""
|
| 28 |
+
log.info(f"Generating private key: {filename}...")
|
| 29 |
+
private_key = rsa.generate_private_key(
|
| 30 |
+
public_exponent=65537, key_size=4096, backend=default_backend()
|
| 31 |
+
)
|
| 32 |
+
key_path = os.path.join(self.cert_dir, filename)
|
| 33 |
+
with open(key_path, "wb") as f:
|
| 34 |
+
f.write(
|
| 35 |
+
private_key.private_bytes(
|
| 36 |
+
encoding=serialization.Encoding.PEM,
|
| 37 |
+
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
| 38 |
+
encryption_algorithm=serialization.NoEncryption(),
|
| 39 |
+
)
|
| 40 |
+
)
|
| 41 |
+
log.info(f"Private key saved to {key_path}")
|
| 42 |
+
return private_key
|
| 43 |
+
|
| 44 |
+
def create_ca(self) -> tuple[rsa.RSAPrivateKey, x509.Certificate]:
|
| 45 |
+
"""Generates a new self-signed CA key and certificate."""
|
| 46 |
+
log.info("Generating a new self-signed CA...")
|
| 47 |
+
ca_private_key = self._generate_private_key("ca.key")
|
| 48 |
+
|
| 49 |
+
subject = issuer = x509.Name(
|
| 50 |
+
[
|
| 51 |
+
x509.NameAttribute(NameOID.COUNTRY_NAME, "DE"),
|
| 52 |
+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "BlossomTune CA"),
|
| 53 |
+
x509.NameAttribute(NameOID.COMMON_NAME, "BlossomTune Self-Signed CA"),
|
| 54 |
+
]
|
| 55 |
+
)
|
| 56 |
+
cert_builder = (
|
| 57 |
+
x509.CertificateBuilder()
|
| 58 |
+
.subject_name(subject)
|
| 59 |
+
.issuer_name(issuer)
|
| 60 |
+
.public_key(ca_private_key.public_key())
|
| 61 |
+
.serial_number(x509.random_serial_number())
|
| 62 |
+
.not_valid_before(datetime.datetime.utcnow())
|
| 63 |
+
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=730))
|
| 64 |
+
.add_extension(
|
| 65 |
+
x509.BasicConstraints(ca=True, path_length=None), critical=True
|
| 66 |
+
)
|
| 67 |
+
)
|
| 68 |
+
ca_cert = cert_builder.sign(ca_private_key, hashes.SHA256(), default_backend())
|
| 69 |
+
|
| 70 |
+
ca_cert_path = os.path.join(self.cert_dir, "ca.crt")
|
| 71 |
+
with open(ca_cert_path, "wb") as f:
|
| 72 |
+
f.write(ca_cert.public_bytes(serialization.Encoding.PEM))
|
| 73 |
+
log.info(f"CA certificate saved to {ca_cert_path}")
|
| 74 |
+
return ca_private_key, ca_cert
|
| 75 |
+
|
| 76 |
+
def generate_server_certificate(
|
| 77 |
+
self,
|
| 78 |
+
common_name: str,
|
| 79 |
+
sans: list[str] | None = None,
|
| 80 |
+
ca_key_path: str | None = None,
|
| 81 |
+
ca_cert_path: str | None = None,
|
| 82 |
+
):
|
| 83 |
+
"""
|
| 84 |
+
Generates a server key, signs a certificate, and creates a combined server.pem file.
|
| 85 |
+
- If a CA is provided, it uses it.
|
| 86 |
+
- Otherwise, it generates a new self-signed CA first.
|
| 87 |
+
"""
|
| 88 |
+
server_private_key = self._generate_private_key("server.key")
|
| 89 |
+
|
| 90 |
+
if ca_key_path and ca_cert_path:
|
| 91 |
+
log.info(f"Loading existing CA from {ca_cert_path}")
|
| 92 |
+
with open(ca_key_path, "rb") as f:
|
| 93 |
+
ca_private_key = serialization.load_pem_private_key(
|
| 94 |
+
f.read(), password=None
|
| 95 |
+
)
|
| 96 |
+
with open(ca_cert_path, "rb") as f:
|
| 97 |
+
ca_cert = x509.load_pem_x509_certificate(f.read())
|
| 98 |
+
else:
|
| 99 |
+
ca_private_key, ca_cert = self.create_ca()
|
| 100 |
+
|
| 101 |
+
subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])
|
| 102 |
+
san_list = [x509.DNSName(common_name)]
|
| 103 |
+
if sans:
|
| 104 |
+
for name in set(sans):
|
| 105 |
+
try:
|
| 106 |
+
san_list.append(x509.IPAddress(ip_address(name)))
|
| 107 |
+
except ValueError:
|
| 108 |
+
san_list.append(x509.DNSName(name))
|
| 109 |
+
|
| 110 |
+
builder = (
|
| 111 |
+
x509.CertificateBuilder()
|
| 112 |
+
.subject_name(subject)
|
| 113 |
+
.issuer_name(ca_cert.subject)
|
| 114 |
+
.public_key(server_private_key.public_key())
|
| 115 |
+
.serial_number(x509.random_serial_number())
|
| 116 |
+
.not_valid_before(datetime.datetime.utcnow())
|
| 117 |
+
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
|
| 118 |
+
.add_extension(x509.SubjectAlternativeName(san_list), critical=False)
|
| 119 |
+
)
|
| 120 |
+
server_cert = builder.sign(ca_private_key, hashes.SHA256(), default_backend())
|
| 121 |
+
|
| 122 |
+
# Save the individual server certificate (.crt)
|
| 123 |
+
server_cert_path = os.path.join(self.cert_dir, "server.crt")
|
| 124 |
+
server_cert_bytes = server_cert.public_bytes(serialization.Encoding.PEM)
|
| 125 |
+
with open(server_cert_path, "wb") as f:
|
| 126 |
+
f.write(server_cert_bytes)
|
| 127 |
+
log.info(f"Server certificate saved to {server_cert_path}")
|
| 128 |
+
|
| 129 |
+
# Create the combined server.pem file for Flower Superlink
|
| 130 |
+
server_key_bytes = server_private_key.private_bytes(
|
| 131 |
+
encoding=serialization.Encoding.PEM,
|
| 132 |
+
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
| 133 |
+
encryption_algorithm=serialization.NoEncryption(),
|
| 134 |
+
)
|
| 135 |
+
server_pem_path = os.path.join(self.cert_dir, "server.pem")
|
| 136 |
+
with open(server_pem_path, "wb") as f:
|
| 137 |
+
f.write(server_cert_bytes)
|
| 138 |
+
f.write(server_key_bytes)
|
| 139 |
+
log.info(f"Server PEM file (cert + key) saved to {server_pem_path}")
|
blossomtune_gradio/ui/callbacks.py
CHANGED
|
@@ -8,6 +8,7 @@ from blossomtune_gradio.logs import log
|
|
| 8 |
from blossomtune_gradio import federation as fed
|
| 9 |
from blossomtune_gradio import processing
|
| 10 |
from blossomtune_gradio.settings import settings
|
|
|
|
| 11 |
|
| 12 |
from . import components
|
| 13 |
from . import auth
|
|
@@ -24,7 +25,7 @@ def get_full_status_update(
|
|
| 24 |
profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None
|
| 25 |
):
|
| 26 |
owner = auth.is_space_owner(profile, oauth_token)
|
| 27 |
-
auth_status = "Authenticating..."
|
| 28 |
is_on_space = cfg.SPACE_OWNER is not None
|
| 29 |
hf_handle_val = ""
|
| 30 |
hf_handle_interactive = not is_on_space
|
|
@@ -53,27 +54,42 @@ def get_full_status_update(
|
|
| 53 |
"SELECT participant_id, hf_handle, email, partition_id FROM requests WHERE status = 'approved' ORDER BY timestamp DESC"
|
| 54 |
).fetchall()
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
runner_is_running = (
|
| 61 |
processing.process_store["runner"]
|
| 62 |
and processing.process_store["runner"].poll() is None
|
| 63 |
)
|
| 64 |
-
|
| 65 |
-
# Hardcode status text as it's not in the schema
|
| 66 |
-
superlink_status = "🟢 Running" if superlink_is_running else "🔴 Not Running"
|
| 67 |
runner_status = "🟢 Running" if runner_is_running else "🔴 Not Running"
|
| 68 |
|
| 69 |
-
# Hardcode button text as it's not in the schema
|
| 70 |
-
if superlink_is_running:
|
| 71 |
-
superlink_btn_update = gr.update(value="🛑 Stop Superlink", variant="stop")
|
| 72 |
-
else:
|
| 73 |
-
superlink_btn_update = gr.update(
|
| 74 |
-
value="🚀 Start Superlink", variant="secondary"
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
if runner_is_running:
|
| 78 |
runner_btn_update = gr.update(value="🛑 Stop Runner", variant="stop")
|
| 79 |
else:
|
|
|
|
| 8 |
from blossomtune_gradio import federation as fed
|
| 9 |
from blossomtune_gradio import processing
|
| 10 |
from blossomtune_gradio.settings import settings
|
| 11 |
+
from blossomtune_gradio import util
|
| 12 |
|
| 13 |
from . import components
|
| 14 |
from . import auth
|
|
|
|
| 25 |
profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None
|
| 26 |
):
|
| 27 |
owner = auth.is_space_owner(profile, oauth_token)
|
| 28 |
+
auth_status = "Authenticating..."
|
| 29 |
is_on_space = cfg.SPACE_OWNER is not None
|
| 30 |
hf_handle_val = ""
|
| 31 |
hf_handle_interactive = not is_on_space
|
|
|
|
| 54 |
"SELECT participant_id, hf_handle, email, partition_id FROM requests WHERE status = 'approved' ORDER BY timestamp DESC"
|
| 55 |
).fetchall()
|
| 56 |
|
| 57 |
+
# Superlink Status Logic
|
| 58 |
+
superlink_btn_update = gr.update() # Default empty update
|
| 59 |
+
|
| 60 |
+
if cfg.SUPERLINK_MODE == "internal":
|
| 61 |
+
superlink_is_running = (
|
| 62 |
+
processing.process_store["superlink"]
|
| 63 |
+
and processing.process_store["superlink"].poll() is None
|
| 64 |
+
)
|
| 65 |
+
superlink_status = "🟢 Running" if superlink_is_running else "🔴 Not Running"
|
| 66 |
+
if superlink_is_running:
|
| 67 |
+
superlink_btn_update = gr.update(
|
| 68 |
+
value="🛑 Stop Superlink", variant="stop", interactive=True
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
superlink_btn_update = gr.update(
|
| 72 |
+
value="🚀 Start Superlink", variant="secondary", interactive=True
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
elif cfg.SUPERLINK_MODE == "external":
|
| 76 |
+
if not cfg.SUPERLINK_HOST:
|
| 77 |
+
superlink_status = "🔴 Not Configured"
|
| 78 |
+
else:
|
| 79 |
+
is_open = util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT)
|
| 80 |
+
superlink_status = "🟢 Running" if is_open else "🔴 Not Running"
|
| 81 |
+
# Disable the button in external mode
|
| 82 |
+
superlink_btn_update = gr.update(value="Managed Externally", interactive=False)
|
| 83 |
+
else:
|
| 84 |
+
superlink_status = "⚠️ Invalid Mode"
|
| 85 |
+
superlink_btn_update = gr.update(interactive=False)
|
| 86 |
+
|
| 87 |
runner_is_running = (
|
| 88 |
processing.process_store["runner"]
|
| 89 |
and processing.process_store["runner"].poll() is None
|
| 90 |
)
|
|
|
|
|
|
|
|
|
|
| 91 |
runner_status = "🟢 Running" if runner_is_running else "🔴 Not Running"
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
if runner_is_running:
|
| 94 |
runner_btn_update = gr.update(value="🛑 Stop Runner", variant="stop")
|
| 95 |
else:
|
blossomtune_gradio/util.py
CHANGED
|
@@ -1,7 +1,32 @@
|
|
| 1 |
import re
|
|
|
|
| 2 |
import dns.resolver
|
| 3 |
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def validate_email(email_address: str) -> bool:
|
| 6 |
"""
|
| 7 |
Validates an email address using regex for format and
|
|
|
|
| 1 |
import re
|
| 2 |
+
import socket
|
| 3 |
import dns.resolver
|
| 4 |
|
| 5 |
|
| 6 |
+
def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
|
| 7 |
+
"""
|
| 8 |
+
Checks if a TCP port is open on a given host.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
host: The hostname or IP address to check.
|
| 12 |
+
port: The port number to check.
|
| 13 |
+
timeout: The connection timeout in seconds.
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
True if the port is open and a connection can be established,
|
| 17 |
+
False otherwise.
|
| 18 |
+
"""
|
| 19 |
+
try:
|
| 20 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 21 |
+
s.settimeout(timeout)
|
| 22 |
+
s.connect((host, port))
|
| 23 |
+
print(f"TCP check successful: Port {port} is open on {host}.")
|
| 24 |
+
return True
|
| 25 |
+
except (socket.timeout, ConnectionRefusedError, OSError) as e:
|
| 26 |
+
print(f"TCP check failed: Port {port} on {host} is not open. Error: {e}")
|
| 27 |
+
return False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
def validate_email(email_address: str) -> bool:
|
| 31 |
"""
|
| 32 |
Validates an email address using regex for format and
|
data/cache/.gitkeep
ADDED
|
File without changes
|
data/certs/.gitkeep
ADDED
|
File without changes
|
data/db/.gitkeep
ADDED
|
File without changes
|
data/keys/.gitkeep
ADDED
|
File without changes
|
docker-compose.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
gradio_app:
|
| 3 |
+
image: ethicalabs/blossomtune-orchestrator:latest
|
| 4 |
+
container_name: gradio_app
|
| 5 |
+
build:
|
| 6 |
+
context: .
|
| 7 |
+
dockerfile: Dockerfile
|
| 8 |
+
command: ["gradio_app"]
|
| 9 |
+
ports:
|
| 10 |
+
- "7860:7860" # Expose the Gradio port to the host machine
|
| 11 |
+
volumes:
|
| 12 |
+
- ./data/db:/data/db # Mount the database directory for persistence
|
| 13 |
+
- ./data/certs:/data/certs:ro # Mount TLS certificates (read-only)
|
| 14 |
+
- ./data/keys:/data/keys:rw # Mount authentication keys (read-write)
|
| 15 |
+
- ./data/cache:/root/.cache # Mount Hugging Face cache for models/datasets
|
| 16 |
+
depends_on:
|
| 17 |
+
- superlink # Optional: Ensures superlink starts before the UI
|
| 18 |
+
environment:
|
| 19 |
+
HF_TOKEN: ${HF_TOKEN}
|
| 20 |
+
SUPERLINK_MODE: external
|
| 21 |
+
SUPERLINK_HOST: host.docker.internal
|
| 22 |
+
superlink:
|
| 23 |
+
image: ethicalabs/blossomtune-orchestrator:latest
|
| 24 |
+
container_name: superlink
|
| 25 |
+
build:
|
| 26 |
+
context: .
|
| 27 |
+
dockerfile: Dockerfile
|
| 28 |
+
command: ["superlink"]
|
| 29 |
+
ports:
|
| 30 |
+
- "127.0.0.1:9092:9092" # Port for SuperNode connections
|
| 31 |
+
- "9093:9093" # Port for Flower CLI (e.g., flwr run)
|
| 32 |
+
volumes:
|
| 33 |
+
- ./data/certs:/data/certs:ro # Mount TLS certificates (read-only)
|
| 34 |
+
- ./data/keys:/data/keys:ro # Mount authentication keys (read-only)
|
| 35 |
+
- ./data/results:/data/results # Mount results directory for persisting run artifacts
|
| 36 |
+
mailhog:
|
| 37 |
+
image: mailhog/mailhog
|
| 38 |
+
container_name: mailhog
|
| 39 |
+
restart: always
|
| 40 |
+
volumes:
|
| 41 |
+
- ./mailhog.auth:/mailhog.auth:ro
|
| 42 |
+
- ./data:/data:rw
|
| 43 |
+
ports:
|
| 44 |
+
- "1025:1025"
|
| 45 |
+
- "8025:8025"
|
docker_entrypoint.sh
CHANGED
|
@@ -8,6 +8,12 @@ env | grep -v -E "SECRET|KEY|TOKEN|PASSWORD|ACCESS"
|
|
| 8 |
if [ "${1}" = "gradio_app" ]; then
|
| 9 |
echo "Running Gradio App..."
|
| 10 |
exec python3 -m blossomtune_gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
else
|
| 12 |
exec "$@"
|
| 13 |
fi
|
|
|
|
| 8 |
if [ "${1}" = "gradio_app" ]; then
|
| 9 |
echo "Running Gradio App..."
|
| 10 |
exec python3 -m blossomtune_gradio
|
| 11 |
+
elif [ "${1}" = "superlink" ]; then
|
| 12 |
+
echo "Running Superlink..."
|
| 13 |
+
exec flower-superlink \
|
| 14 |
+
--ssl-ca-certfile /data/certs/ca.crt \
|
| 15 |
+
--ssl-certfile /data/certs/server.pem \
|
| 16 |
+
--ssl-keyfile /data/certs/server.key
|
| 17 |
else
|
| 18 |
exec "$@"
|
| 19 |
fi
|
pyproject.toml
CHANGED
|
@@ -68,6 +68,7 @@ convention = "google" # Accepts: "google", "numpy", or "pep257".
|
|
| 68 |
|
| 69 |
[dependency-groups]
|
| 70 |
dev = [
|
|
|
|
| 71 |
"dnspython>=2.8.0",
|
| 72 |
"pytest>=8.4.1",
|
| 73 |
"pytest-mock>=3.15.1",
|
|
|
|
| 68 |
|
| 69 |
[dependency-groups]
|
| 70 |
dev = [
|
| 71 |
+
"cryptography>=44.0.3",
|
| 72 |
"dnspython>=2.8.0",
|
| 73 |
"pytest>=8.4.1",
|
| 74 |
"pytest-mock>=3.15.1",
|
tests/test_auth_keys.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import stat
|
| 4 |
+
import pytest
|
| 5 |
+
from cryptography.hazmat.primitives import serialization
|
| 6 |
+
from cryptography.hazmat.primitives.asymmetric import ec
|
| 7 |
+
|
| 8 |
+
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def key_generator(tmp_path):
|
| 13 |
+
"""Fixture to create an AuthKeyGenerator instance in a temporary directory."""
|
| 14 |
+
key_dir = tmp_path / "auth_keys"
|
| 15 |
+
return AuthKeyGenerator(key_dir=str(key_dir))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestAuthKeyGenerator:
|
| 19 |
+
"""Test suite for the AuthKeyGenerator class."""
|
| 20 |
+
|
| 21 |
+
def test_init_creates_directory(self, tmp_path):
|
| 22 |
+
"""Verify that the key directory is created on initialization."""
|
| 23 |
+
key_dir = tmp_path / "new_keys"
|
| 24 |
+
assert not os.path.exists(key_dir)
|
| 25 |
+
AuthKeyGenerator(key_dir=str(key_dir))
|
| 26 |
+
assert os.path.exists(key_dir)
|
| 27 |
+
|
| 28 |
+
def test_generate_participant_keys_creates_files_and_returns_data(
|
| 29 |
+
self, key_generator
|
| 30 |
+
):
|
| 31 |
+
"""
|
| 32 |
+
Verify that the main method generates all expected files and returns
|
| 33 |
+
the correct data tuple.
|
| 34 |
+
"""
|
| 35 |
+
participant_id = "participant_01"
|
| 36 |
+
priv_path, pub_path, pub_pem = key_generator.generate_participant_keys(
|
| 37 |
+
participant_id
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# 1. Check if files exist
|
| 41 |
+
assert os.path.exists(priv_path)
|
| 42 |
+
assert os.path.exists(pub_path)
|
| 43 |
+
assert priv_path == os.path.join(key_generator.key_dir, f"{participant_id}.key")
|
| 44 |
+
assert pub_path == os.path.join(key_generator.key_dir, f"{participant_id}.pub")
|
| 45 |
+
|
| 46 |
+
# 2. Check private key file permissions (security check)
|
| 47 |
+
# In non-Windows environments, check for 600 permissions.
|
| 48 |
+
if os.name != "nt":
|
| 49 |
+
file_mode = stat.S_IMODE(os.stat(priv_path).st_mode)
|
| 50 |
+
assert file_mode == 0o600
|
| 51 |
+
|
| 52 |
+
# 3. Verify key formats and consistency
|
| 53 |
+
with open(priv_path, "rb") as f:
|
| 54 |
+
private_key = serialization.load_pem_private_key(f.read(), password=None)
|
| 55 |
+
with open(pub_path, "rb") as f:
|
| 56 |
+
public_key = serialization.load_pem_public_key(f.read())
|
| 57 |
+
_ = f.read() # Read again to get bytes
|
| 58 |
+
|
| 59 |
+
assert isinstance(private_key, ec.EllipticCurvePrivateKey)
|
| 60 |
+
assert isinstance(public_key, ec.EllipticCurvePublicKey)
|
| 61 |
+
|
| 62 |
+
# Check that the returned PEM string matches the public key
|
| 63 |
+
generated_public_key = private_key.public_key()
|
| 64 |
+
pem_from_private = generated_public_key.public_bytes(
|
| 65 |
+
encoding=serialization.Encoding.PEM,
|
| 66 |
+
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
| 67 |
+
).decode("utf-8")
|
| 68 |
+
assert pub_pem == pem_from_private
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TestRebuildCSV:
|
| 72 |
+
"""Test suite for the rebuild_authorized_keys_csv function."""
|
| 73 |
+
|
| 74 |
+
def test_rebuild_csv_creates_file_with_header_for_empty_list(self, tmp_path):
|
| 75 |
+
"""Verify a CSV with only a header is created for an empty participant list."""
|
| 76 |
+
key_dir = tmp_path / "csv_test"
|
| 77 |
+
os.makedirs(key_dir)
|
| 78 |
+
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
|
| 79 |
+
|
| 80 |
+
rebuild_authorized_keys_csv(key_dir, [])
|
| 81 |
+
|
| 82 |
+
assert os.path.exists(csv_path)
|
| 83 |
+
with open(csv_path, "r") as f:
|
| 84 |
+
reader = csv.reader(f)
|
| 85 |
+
header = next(reader)
|
| 86 |
+
assert header == ["participant_id", "public_key_pem"]
|
| 87 |
+
# Check that there are no more rows
|
| 88 |
+
with pytest.raises(StopIteration):
|
| 89 |
+
next(reader)
|
| 90 |
+
|
| 91 |
+
def test_rebuild_csv_writes_correct_data(self, tmp_path):
|
| 92 |
+
"""Verify the CSV is created with the correct participant data."""
|
| 93 |
+
key_dir = tmp_path / "csv_test"
|
| 94 |
+
os.makedirs(key_dir)
|
| 95 |
+
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
|
| 96 |
+
|
| 97 |
+
participants = [
|
| 98 |
+
("p1", "---BEGIN PUBLIC KEY---...p1...---END PUBLIC KEY---"),
|
| 99 |
+
("p2", "---BEGIN PUBLIC KEY---...p2...---END PUBLIC KEY---"),
|
| 100 |
+
]
|
| 101 |
+
rebuild_authorized_keys_csv(key_dir, participants)
|
| 102 |
+
|
| 103 |
+
with open(csv_path, "r") as f:
|
| 104 |
+
reader = csv.reader(f)
|
| 105 |
+
header = next(reader)
|
| 106 |
+
row1 = next(reader)
|
| 107 |
+
row2 = next(reader)
|
| 108 |
+
assert header == ["participant_id", "public_key_pem"]
|
| 109 |
+
assert row1 == list(participants[0])
|
| 110 |
+
assert row2 == list(participants[1])
|
| 111 |
+
|
| 112 |
+
def test_rebuild_csv_overwrites_existing_file(self, tmp_path):
|
| 113 |
+
"""Verify that an existing CSV file is correctly overwritten."""
|
| 114 |
+
key_dir = tmp_path / "csv_test"
|
| 115 |
+
os.makedirs(key_dir)
|
| 116 |
+
csv_path = os.path.join(key_dir, "authorized_supernodes.csv")
|
| 117 |
+
|
| 118 |
+
# First run with initial data
|
| 119 |
+
initial_participants = [("old_p1", "old_key_1")]
|
| 120 |
+
rebuild_authorized_keys_csv(key_dir, initial_participants)
|
| 121 |
+
|
| 122 |
+
# Second run with new data
|
| 123 |
+
new_participants = [("new_p1", "new_key_1"), ("new_p2", "new_key_2")]
|
| 124 |
+
rebuild_authorized_keys_csv(key_dir, new_participants)
|
| 125 |
+
|
| 126 |
+
with open(csv_path, "r") as f:
|
| 127 |
+
reader = csv.reader(f)
|
| 128 |
+
_ = next(reader)
|
| 129 |
+
rows = list(reader)
|
| 130 |
+
|
| 131 |
+
assert len(rows) == 2
|
| 132 |
+
assert rows[0] == list(new_participants[0])
|
| 133 |
+
assert rows[1] == list(new_participants[1])
|
tests/test_blossomfile.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import zipfile
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from blossomtune_gradio.blossomfile import create_blossomfile
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@pytest.fixture
|
| 10 |
+
def dummy_credential_files(tmp_path):
|
| 11 |
+
"""
|
| 12 |
+
Pytest fixture to create dummy credential files in a temporary directory.
|
| 13 |
+
Returns the paths to the created files.
|
| 14 |
+
"""
|
| 15 |
+
creds_dir = tmp_path / "creds"
|
| 16 |
+
os.makedirs(creds_dir)
|
| 17 |
+
|
| 18 |
+
ca_cert_path = creds_dir / "ca.crt"
|
| 19 |
+
auth_key_path = creds_dir / "auth.key"
|
| 20 |
+
auth_pub_path = creds_dir / "auth.pub"
|
| 21 |
+
|
| 22 |
+
ca_cert_path.write_text("---BEGIN CERTIFICATE---")
|
| 23 |
+
auth_key_path.write_text("---BEGIN EC PRIVATE KEY---")
|
| 24 |
+
auth_pub_path.write_text("---BEGIN PUBLIC KEY---")
|
| 25 |
+
|
| 26 |
+
return {
|
| 27 |
+
"ca_cert_path": str(ca_cert_path),
|
| 28 |
+
"auth_key_path": str(auth_key_path),
|
| 29 |
+
"auth_pub_path": str(auth_pub_path),
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_create_blossomfile_success(tmp_path, dummy_credential_files):
|
| 34 |
+
"""
|
| 35 |
+
Tests the successful creation of a .blossomfile, verifying its contents.
|
| 36 |
+
"""
|
| 37 |
+
output_dir = tmp_path / "output"
|
| 38 |
+
participant_id = "participant_abc"
|
| 39 |
+
|
| 40 |
+
blossomfile_path = create_blossomfile(
|
| 41 |
+
participant_id=participant_id,
|
| 42 |
+
output_dir=str(output_dir),
|
| 43 |
+
ca_cert_path=dummy_credential_files["ca_cert_path"],
|
| 44 |
+
auth_key_path=dummy_credential_files["auth_key_path"],
|
| 45 |
+
auth_pub_path=dummy_credential_files["auth_pub_path"],
|
| 46 |
+
superlink_address="blossomtune-test.ethicalabs.ai:9092",
|
| 47 |
+
partition_id=5,
|
| 48 |
+
num_partitions=10,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# 1. Verify the file was created at the correct path
|
| 52 |
+
assert os.path.exists(blossomfile_path)
|
| 53 |
+
assert blossomfile_path == str(output_dir / f"{participant_id}.blossomfile")
|
| 54 |
+
|
| 55 |
+
# 2. Verify the contents of the zip archive
|
| 56 |
+
with zipfile.ZipFile(blossomfile_path, "r") as zf:
|
| 57 |
+
# Check that all expected files are present
|
| 58 |
+
namelist = zf.namelist()
|
| 59 |
+
assert "blossom.json" in namelist
|
| 60 |
+
assert "ca.crt" in namelist
|
| 61 |
+
assert "auth.key" in namelist
|
| 62 |
+
assert "auth.pub" in namelist
|
| 63 |
+
assert len(namelist) == 4
|
| 64 |
+
|
| 65 |
+
# Check the content of blossom.json
|
| 66 |
+
with zf.open("blossom.json") as f:
|
| 67 |
+
config_data = json.load(f)
|
| 68 |
+
assert (
|
| 69 |
+
config_data["superlink_address"]
|
| 70 |
+
== "blossomtune-test.ethicalabs.ai:9092"
|
| 71 |
+
)
|
| 72 |
+
assert config_data["node_config"]["partition-id"] == 5
|
| 73 |
+
assert config_data["node_config"]["num-partitions"] == 10
|
| 74 |
+
|
| 75 |
+
# Check the content of the credential files
|
| 76 |
+
assert zf.read("ca.crt").decode("utf-8") == "---BEGIN CERTIFICATE---"
|
| 77 |
+
assert zf.read("auth.key").decode("utf-8") == "---BEGIN EC PRIVATE KEY---"
|
| 78 |
+
assert zf.read("auth.pub").decode("utf-8") == "---BEGIN PUBLIC KEY---"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def test_create_blossomfile_missing_input_file(tmp_path):
|
| 82 |
+
"""
|
| 83 |
+
Tests that the function raises FileNotFoundError if a required credential
|
| 84 |
+
file is missing and cleans up the partial archive.
|
| 85 |
+
"""
|
| 86 |
+
output_dir = tmp_path / "output"
|
| 87 |
+
participant_id = "participant_xyz"
|
| 88 |
+
missing_file_path = tmp_path / "creds" / "non_existent.key"
|
| 89 |
+
blossomfile_path = output_dir / f"{participant_id}.blossomfile"
|
| 90 |
+
|
| 91 |
+
with pytest.raises(FileNotFoundError, match="Required credential file not found"):
|
| 92 |
+
create_blossomfile(
|
| 93 |
+
participant_id=participant_id,
|
| 94 |
+
output_dir=str(output_dir),
|
| 95 |
+
ca_cert_path=str(missing_file_path), # Pass a path that doesn't exist
|
| 96 |
+
auth_key_path="dummy",
|
| 97 |
+
auth_pub_path="dummy",
|
| 98 |
+
superlink_address="test:9092",
|
| 99 |
+
partition_id=1,
|
| 100 |
+
num_partitions=2,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Verify that the partially created blossomfile was removed
|
| 104 |
+
assert not os.path.exists(blossomfile_path)
|
tests/test_tls.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pytest
|
| 3 |
+
from cryptography import x509
|
| 4 |
+
|
| 5 |
+
from blossomtune_gradio.tls import TLSGenerator
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@pytest.fixture
|
| 9 |
+
def tls_generator(tmp_path):
|
| 10 |
+
"""Fixture to create a TLSGenerator instance in a temporary directory."""
|
| 11 |
+
cert_dir = tmp_path / "certs"
|
| 12 |
+
return TLSGenerator(cert_dir=str(cert_dir))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestTLSGenerator:
|
| 16 |
+
"""Test suite for the TLSGenerator class."""
|
| 17 |
+
|
| 18 |
+
def test_init_creates_directory(self, tmp_path):
|
| 19 |
+
"""Verify that the certificate directory is created on initialization."""
|
| 20 |
+
cert_dir = tmp_path / "new_certs"
|
| 21 |
+
assert not os.path.exists(cert_dir)
|
| 22 |
+
TLSGenerator(cert_dir=str(cert_dir))
|
| 23 |
+
assert os.path.exists(cert_dir)
|
| 24 |
+
|
| 25 |
+
def test_create_ca(self, tls_generator):
|
| 26 |
+
"""Test the creation of a self-signed Certificate Authority."""
|
| 27 |
+
ca_key, ca_cert = tls_generator.create_ca()
|
| 28 |
+
|
| 29 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.key"))
|
| 30 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.crt"))
|
| 31 |
+
assert ca_cert.issuer == ca_cert.subject
|
| 32 |
+
assert (
|
| 33 |
+
ca_cert.extensions.get_extension_for_class(x509.BasicConstraints).value.ca
|
| 34 |
+
is True
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def test_generate_server_certificate_with_new_ca(self, tls_generator):
|
| 38 |
+
"""
|
| 39 |
+
Test generating a server certificate, which should also create a new CA
|
| 40 |
+
and the combined server.pem file.
|
| 41 |
+
"""
|
| 42 |
+
common_name = "test.local"
|
| 43 |
+
sans = ["test.local", "192.168.1.10"]
|
| 44 |
+
tls_generator.generate_server_certificate(common_name=common_name, sans=sans)
|
| 45 |
+
|
| 46 |
+
# Check for all expected files
|
| 47 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.key"))
|
| 48 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "ca.crt"))
|
| 49 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "server.key"))
|
| 50 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "server.crt"))
|
| 51 |
+
assert os.path.exists(os.path.join(tls_generator.cert_dir, "server.pem"))
|
| 52 |
+
|
| 53 |
+
# Load certs to verify issuer relationship
|
| 54 |
+
with open(os.path.join(tls_generator.cert_dir, "ca.crt"), "rb") as f:
|
| 55 |
+
ca_cert = x509.load_pem_x509_certificate(f.read())
|
| 56 |
+
with open(os.path.join(tls_generator.cert_dir, "server.crt"), "rb") as f:
|
| 57 |
+
server_cert = x509.load_pem_x509_certificate(f.read())
|
| 58 |
+
|
| 59 |
+
assert server_cert.issuer == ca_cert.subject
|
| 60 |
+
|
| 61 |
+
# Verify PEM content
|
| 62 |
+
with open(os.path.join(tls_generator.cert_dir, "server.pem"), "r") as f:
|
| 63 |
+
pem_content = f.read()
|
| 64 |
+
assert "-----BEGIN CERTIFICATE-----" in pem_content
|
| 65 |
+
assert "-----BEGIN RSA PRIVATE KEY-----" in pem_content
|
| 66 |
+
|
| 67 |
+
def test_generate_server_certificate_with_existing_ca(self, tls_generator):
|
| 68 |
+
"""Test generating a server certificate using a pre-existing CA."""
|
| 69 |
+
tls_generator.create_ca()
|
| 70 |
+
ca_key_path = os.path.join(tls_generator.cert_dir, "ca.key")
|
| 71 |
+
ca_cert_path = os.path.join(tls_generator.cert_dir, "ca.crt")
|
| 72 |
+
|
| 73 |
+
server_gen_dir = os.path.join(
|
| 74 |
+
os.path.dirname(tls_generator.cert_dir), "server_certs"
|
| 75 |
+
)
|
| 76 |
+
server_generator = TLSGenerator(cert_dir=server_gen_dir)
|
| 77 |
+
|
| 78 |
+
server_generator.generate_server_certificate(
|
| 79 |
+
common_name="prod.local", ca_key_path=ca_key_path, ca_cert_path=ca_cert_path
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Check that server files were created in the new directory
|
| 83 |
+
assert os.path.exists(os.path.join(server_gen_dir, "server.key"))
|
| 84 |
+
assert os.path.exists(os.path.join(server_gen_dir, "server.crt"))
|
| 85 |
+
assert os.path.exists(os.path.join(server_gen_dir, "server.pem"))
|
| 86 |
+
# Check that a *new* CA was NOT created in the server directory
|
| 87 |
+
assert not os.path.exists(os.path.join(server_gen_dir, "ca.key"))
|
| 88 |
+
|
| 89 |
+
# Verify PEM content
|
| 90 |
+
with open(os.path.join(server_gen_dir, "server.pem"), "r") as f:
|
| 91 |
+
pem_content = f.read()
|
| 92 |
+
assert "-----BEGIN CERTIFICATE-----" in pem_content
|
| 93 |
+
assert "-----BEGIN RSA PRIVATE KEY-----" in pem_content
|
tests/test_util.py
CHANGED
|
@@ -1,8 +1,43 @@
|
|
| 1 |
import pytest
|
|
|
|
| 2 |
import dns.resolver
|
| 3 |
from unittest.mock import MagicMock
|
| 4 |
|
| 5 |
-
from blossomtune_gradio.util import validate_email, strtobool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def test_validate_email_valid(monkeypatch):
|
|
|
|
| 1 |
import pytest
|
| 2 |
+
import socket
|
| 3 |
import dns.resolver
|
| 4 |
from unittest.mock import MagicMock
|
| 5 |
|
| 6 |
+
from blossomtune_gradio.util import is_port_open, validate_email, strtobool
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_is_port_open_success(mocker):
|
| 10 |
+
"""
|
| 11 |
+
Tests the case where the port is open and the connection succeeds.
|
| 12 |
+
"""
|
| 13 |
+
mock_socket = mocker.patch("blossomtune_gradio.util.socket.socket")
|
| 14 |
+
mock_socket.return_value.__enter__.return_value.connect.return_value = None
|
| 15 |
+
|
| 16 |
+
result = is_port_open("testhost", 1234)
|
| 17 |
+
|
| 18 |
+
assert result is True
|
| 19 |
+
# Verify that the socket was created and connect was called with the correct args
|
| 20 |
+
mock_socket.return_value.__enter__.return_value.settimeout.assert_called_once_with(
|
| 21 |
+
1.0
|
| 22 |
+
)
|
| 23 |
+
mock_socket.return_value.__enter__.return_value.connect.assert_called_once_with(
|
| 24 |
+
("testhost", 1234)
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@pytest.mark.parametrize("exception", [ConnectionRefusedError, socket.timeout, OSError])
|
| 29 |
+
def test_is_port_open_failures(mocker, exception):
|
| 30 |
+
"""
|
| 31 |
+
Tests various failure scenarios where the port is not open.
|
| 32 |
+
- ConnectionRefusedError: The host actively refuses the connection.
|
| 33 |
+
- socket.timeout: The connection attempt times out.
|
| 34 |
+
- OSError: A generic network error occurs.
|
| 35 |
+
"""
|
| 36 |
+
mock_socket = mocker.patch("blossomtune_gradio.util.socket.socket")
|
| 37 |
+
mock_socket.return_value.__enter__.return_value.connect.side_effect = exception
|
| 38 |
+
|
| 39 |
+
result = is_port_open("testhost", 1234)
|
| 40 |
+
assert result is False
|
| 41 |
|
| 42 |
|
| 43 |
def test_validate_email_valid(monkeypatch):
|
uv.lock
CHANGED
|
@@ -241,6 +241,7 @@ dependencies = [
|
|
| 241 |
|
| 242 |
[package.dev-dependencies]
|
| 243 |
dev = [
|
|
|
|
| 244 |
{ name = "dnspython" },
|
| 245 |
{ name = "pytest" },
|
| 246 |
{ name = "pytest-mock" },
|
|
@@ -262,6 +263,7 @@ requires-dist = [
|
|
| 262 |
|
| 263 |
[package.metadata.requires-dev]
|
| 264 |
dev = [
|
|
|
|
| 265 |
{ name = "dnspython", specifier = ">=2.8.0" },
|
| 266 |
{ name = "pytest", specifier = ">=8.4.1" },
|
| 267 |
{ name = "pytest-mock", specifier = ">=3.15.1" },
|
|
|
|
| 241 |
|
| 242 |
[package.dev-dependencies]
|
| 243 |
dev = [
|
| 244 |
+
{ name = "cryptography" },
|
| 245 |
{ name = "dnspython" },
|
| 246 |
{ name = "pytest" },
|
| 247 |
{ name = "pytest-mock" },
|
|
|
|
| 263 |
|
| 264 |
[package.metadata.requires-dev]
|
| 265 |
dev = [
|
| 266 |
+
{ name = "cryptography", specifier = ">=44.0.3" },
|
| 267 |
{ name = "dnspython", specifier = ">=2.8.0" },
|
| 268 |
{ name = "pytest", specifier = ">=8.4.1" },
|
| 269 |
{ name = "pytest-mock", specifier = ">=3.15.1" },
|