Spaces:
Sleeping
Sleeping
File size: 9,248 Bytes
bc67f56 3e25ded bc67f56 3e25ded 4a4c0cf e8143a3 1c0aad9 bc67f56 3e25ded 1c0aad9 3e25ded 1c0aad9 3e25ded 1c0aad9 3e25ded 1c0aad9 bc67f56 3e25ded bc67f56 1c0aad9 bc67f56 1c0aad9 bc67f56 1c0aad9 bc67f56 1c0aad9 3e25ded 1c0aad9 3e25ded 1c0aad9 3e25ded e8143a3 3e25ded 1c0aad9 3e25ded 1c0aad9 3e25ded e8143a3 3e25ded 1c0aad9 bc67f56 1c0aad9 bc67f56 3e25ded bc67f56 3e25ded 1c0aad9 bc67f56 1c0aad9 bc67f56 3e25ded bc67f56 3e25ded 1c0aad9 3e25ded 1c0aad9 3e25ded |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import os
import string
import secrets
import tempfile
from blossomtune_gradio import config as cfg
from blossomtune_gradio import mail
from blossomtune_gradio import util
from blossomtune_gradio.settings import settings
from blossomtune_gradio.database import SessionLocal, Request, Config
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
from blossomtune_gradio.blossomfile import create_blossomfile
def generate_participant_id(length=6):
"""Generates a random, uppercase alphanumeric participant ID."""
alphabet = string.ascii_uppercase + string.digits
return "".join(secrets.choice(alphabet) for _ in range(length))
def generate_activation_code(length=8):
"""Generates a random, uppercase alphanumeric activation code."""
alphabet = string.ascii_uppercase + string.digits
return "".join(secrets.choice(alphabet) for _ in range(length))
def check_participant_status(pid_to_check: str, email: str, activation_code: str):
"""
Handles a participant's request to join, activate, or check status using SQLAlchemy.
Returns a tuple: (is_approved: bool, message: str, data: any | None)
"""
with SessionLocal() as db:
query = db.query(Request).filter(
Request.hf_handle == pid_to_check, Request.email == email
)
if activation_code:
query = query.filter(Request.activation_code == activation_code)
request = query.first()
num_partitions_config = (
db.query(Config).filter(Config.key == "num_partitions").first()
)
num_partitions = num_partitions_config.value if num_partitions_config else "10"
# Case 1 & 2 are for users not yet approved
if not request or not request.is_activated or not activation_code:
if request is None:
if activation_code:
return (False, settings.get_text("activation_invalid_md"), None)
if not util.validate_email(email):
return (False, settings.get_text("invalid_email_md"), None)
approved_count = (
db.query(Request).filter(Request.status == "approved").count()
)
if approved_count >= cfg.MAX_NUM_NODES:
return (False, settings.get_text("federation_full_md"), None)
participant_id = generate_participant_id()
new_activation_code = generate_activation_code()
mail_sent, message = mail.send_activation_email(
email, new_activation_code
)
if mail_sent:
new_request = Request(
participant_id=participant_id,
hf_handle=pid_to_check,
email=email,
activation_code=new_activation_code,
)
db.add(new_request)
db.commit()
return (False, settings.get_text("registration_submitted_md"), None)
else:
return (False, message, None)
if not request.is_activated:
if activation_code == request.activation_code:
request.is_activated = 1
db.commit()
return (False, settings.get_text("activation_successful_md"), None)
else:
return (False, settings.get_text("activation_invalid_md"), None)
if not activation_code:
return (False, settings.get_text("missing_activation_code_md"), None)
# Case 3: Activated user is checking their final status
if request.status == "approved":
hostname = (
"localhost"
if not cfg.SPACE_ID
else f"{cfg.SPACE_ID.split('/')[1]}-{cfg.SPACE_ID.split('/')[0]}.hf.space"
)
superlink_address = f"{cfg.SUPERLINK_HOST or hostname}:{cfg.SUPERLINK_PORT}"
# Blossomfile Generation
blossomfile_tempdir = tempfile.mkdtemp() # TODO: remove tempdirs
try:
blossomfile_path = create_blossomfile(
participant_id=request.participant_id,
output_dir=blossomfile_tempdir,
ca_cert_path=cfg.BLOSSOMTUNE_TLS_CA_CERTFILE,
auth_key_path=os.path.join(
cfg.AUTH_KEYS_DIR, f"{request.participant_id}.key"
),
auth_pub_path=os.path.join(
cfg.AUTH_KEYS_DIR, f"{request.participant_id}.pub"
),
superlink_address=superlink_address,
partition_id=request.partition_id,
num_partitions=int(num_partitions),
)
except FileNotFoundError:
return (False, "An error occurred.", None)
connection_string = settings.get_text(
"status_approved_md",
participant_id=request.participant_id,
partition_id=request.partition_id,
superlink_hostname=superlink_address.split(":")[0],
superlink_port=superlink_address.split(":")[1],
num_partitions=num_partitions,
)
return (True, connection_string, blossomfile_path)
elif request.status == "pending":
return (False, settings.get_text("status_pending_md"), None)
else: # Denied
return (
False,
settings.get_text(
"status_denied_md", participant_id=request.participant_id
),
None,
)
def manage_request(participant_id: str, partition_id: str, action: str):
"""Admin function to approve/deny a request and assign a partition ID."""
if not participant_id:
return False, "Please select a participant from the pending requests table."
with SessionLocal() as db:
request = (
db.query(Request).filter(Request.participant_id == participant_id).first()
)
if not request:
return False, "Participant not found."
if action == "approve":
if not partition_id or not partition_id.isdigit():
return False, "Please provide a valid integer for the Partition ID."
p_id_int = int(partition_id)
if not request.is_activated:
return (
False,
settings.get_text("participant_not_activated_warning_md"),
)
existing_participant = (
db.query(Request)
.filter(Request.partition_id == p_id_int, Request.status == "approved")
.first()
)
if existing_participant:
return (
False,
settings.get_text(
"partition_in_use_warning_md", partition_id=p_id_int
),
)
request.status = "approved"
request.partition_id = p_id_int
# Generate and Store Auth Keys
key_generator = AuthKeyGenerator(key_dir=cfg.AUTH_KEYS_DIR)
_, _, public_key_pem = key_generator.generate_participant_keys(
participant_id
)
request.public_key_pem = public_key_pem
db.commit()
# Rebuild Authorized Keys CSV
approved_participants = (
db.query(Request.participant_id, Request.public_key_pem)
.filter(
Request.status == "approved", Request.public_key_pem.isnot(None)
)
.all()
)
rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)
return (
True,
f"Participant {participant_id} approved. Keys generated and registry updated.",
)
else: # Deny
request.status = "denied"
request.partition_id = None
request.public_key_pem = None
db.commit()
# --- Rebuild CSV after denial to revoke access ---
approved_participants = (
db.query(Request.participant_id, Request.public_key_pem)
.filter(
Request.status == "approved", Request.public_key_pem.isnot(None)
)
.all()
)
rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)
return (
True,
f"Participant {participant_id} denied. Their access has been revoked.",
)
def get_next_partion_id() -> int:
"""Finds the lowest available partition ID."""
with SessionLocal() as db:
used_ids_query = (
db.query(Request.partition_id)
.filter(Request.status == "approved", Request.partition_id.isnot(None))
.all()
)
used_ids = {row[0] for row in used_ids_query}
next_id = 0
while next_id in used_ids:
next_id += 1
return next_id
|