File size: 9,248 Bytes
bc67f56
3e25ded
 
bc67f56
3e25ded
 
 
4a4c0cf
e8143a3
1c0aad9
bc67f56
 
3e25ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c0aad9
 
 
 
 
 
 
 
3e25ded
1c0aad9
3e25ded
1c0aad9
3e25ded
1c0aad9
 
 
 
 
bc67f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e25ded
bc67f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c0aad9
 
 
 
 
 
 
 
bc67f56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c0aad9
 
 
 
 
bc67f56
 
1c0aad9
 
bc67f56
1c0aad9
 
 
 
 
 
 
 
 
 
3e25ded
 
 
 
 
 
 
1c0aad9
 
 
 
 
 
3e25ded
1c0aad9
 
 
 
 
3e25ded
 
e8143a3
3e25ded
 
1c0aad9
 
 
 
3e25ded
1c0aad9
3e25ded
 
e8143a3
 
 
3e25ded
 
1c0aad9
 
bc67f56
 
 
 
 
 
 
1c0aad9
bc67f56
 
 
 
 
 
 
 
 
 
 
3e25ded
 
bc67f56
3e25ded
1c0aad9
 
 
bc67f56
1c0aad9
bc67f56
 
 
 
 
 
 
 
 
 
 
3e25ded
 
bc67f56
3e25ded
 
 
 
1c0aad9
 
 
 
 
 
3e25ded
1c0aad9
3e25ded
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import os
import shutil
import string
import secrets
import tempfile

from blossomtune_gradio import config as cfg
from blossomtune_gradio import mail
from blossomtune_gradio import util
from blossomtune_gradio.settings import settings
from blossomtune_gradio.database import SessionLocal, Request, Config
from blossomtune_gradio.auth_keys import AuthKeyGenerator, rebuild_authorized_keys_csv
from blossomtune_gradio.blossomfile import create_blossomfile


def generate_participant_id(length=6):
    """Return a random participant ID of ``length`` characters.

    Each character is picked with ``secrets.choice`` from the uppercase
    ASCII letters and the decimal digits.
    """
    charset = string.ascii_uppercase + string.digits
    return "".join(map(secrets.choice, [charset] * length))


def generate_activation_code(length=8):
    """Return a random activation code of ``length`` characters.

    Characters are drawn with ``secrets.choice`` from A-Z and 0-9.
    """
    symbols = string.ascii_uppercase + string.digits
    picked = [secrets.choice(symbols) for _count in range(length)]
    return "".join(picked)


def _superlink_endpoint():
    """Return the ``(host, port)`` pair participants use to reach the SuperLink.

    ``cfg.SUPERLINK_HOST`` is used when set. Otherwise, if ``cfg.SPACE_ID`` is
    set (expected form ``owner/name``), the host is ``<name>-<owner>.hf.space``;
    failing both, ``localhost``. The port is always ``cfg.SUPERLINK_PORT``.
    """
    if cfg.SUPERLINK_HOST:
        host = cfg.SUPERLINK_HOST
    elif cfg.SPACE_ID:
        space_parts = cfg.SPACE_ID.split("/")
        host = f"{space_parts[1]}-{space_parts[0]}.hf.space"
    else:
        host = "localhost"
    return host, cfg.SUPERLINK_PORT


def check_participant_status(pid_to_check: str, email: str, activation_code: str):
    """
    Handle a participant's request to join, activate, or check status.

    The stored request is looked up by Hugging Face handle (``pid_to_check``)
    and email, additionally filtered by ``activation_code`` when one is given.
    Depending on what is found this registers a new request (emailing an
    activation code), activates an existing request, or reports the admin's
    decision, generating a Blossomfile for approved participants.

    Args:
        pid_to_check: The participant's Hugging Face handle.
        email: The participant's email address.
        activation_code: Code from the activation email; may be empty.

    Returns:
        A tuple ``(is_approved, message, data)``. ``data`` is the path of the
        generated Blossomfile when the participant is approved, else ``None``.
    """
    with SessionLocal() as db:
        query = db.query(Request).filter(
            Request.hf_handle == pid_to_check, Request.email == email
        )
        if activation_code:
            # Only match a request whose stored code equals the supplied one.
            query = query.filter(Request.activation_code == activation_code)

        request = query.first()

        num_partitions_config = (
            db.query(Config).filter(Config.key == "num_partitions").first()
        )
        # Default to "10" when no num_partitions entry exists in Config.
        num_partitions = num_partitions_config.value if num_partitions_config else "10"

        # Cases 1 & 2: not registered, not activated, or no code supplied.
        # Every path through this branch returns.
        if not request or not request.is_activated or not activation_code:
            if request is None:
                if activation_code:
                    # A code was supplied but no request for this handle/email has it.
                    return (False, settings.get_text("activation_invalid_md"), None)

                # Case 1: new registration.
                if not util.validate_email(email):
                    return (False, settings.get_text("invalid_email_md"), None)
                approved_count = (
                    db.query(Request).filter(Request.status == "approved").count()
                )
                if approved_count >= cfg.MAX_NUM_NODES:
                    return (False, settings.get_text("federation_full_md"), None)

                # NOTE(review): uniqueness of the generated participant_id is not
                # checked here; presumably enforced by the DB schema — confirm.
                participant_id = generate_participant_id()
                new_activation_code = generate_activation_code()
                mail_sent, message = mail.send_activation_email(
                    email, new_activation_code
                )
                if not mail_sent:
                    # Persist nothing when the activation code could not be sent.
                    return (False, message, None)
                db.add(
                    Request(
                        participant_id=participant_id,
                        hf_handle=pid_to_check,
                        email=email,
                        activation_code=new_activation_code,
                    )
                )
                db.commit()
                return (False, settings.get_text("registration_submitted_md"), None)

            # Case 2: registered but not yet activated.
            if not request.is_activated:
                if activation_code == request.activation_code:
                    request.is_activated = 1
                    db.commit()
                    return (False, settings.get_text("activation_successful_md"), None)
                return (False, settings.get_text("activation_invalid_md"), None)

            # Activated, but checking the final status also requires the code.
            if not activation_code:
                return (False, settings.get_text("missing_activation_code_md"), None)

        # Case 3: activated participant checking the admin's decision.
        if request.status == "approved":
            superlink_host, superlink_port = _superlink_endpoint()
            superlink_address = f"{superlink_host}:{superlink_port}"

            # The Blossomfile path is returned to the caller, so its directory
            # must outlive this call on success.
            # TODO: remove tempdirs once the Blossomfile has been delivered.
            blossomfile_tempdir = tempfile.mkdtemp()
            try:
                blossomfile_path = create_blossomfile(
                    participant_id=request.participant_id,
                    output_dir=blossomfile_tempdir,
                    ca_cert_path=cfg.BLOSSOMTUNE_TLS_CA_CERTFILE,
                    auth_key_path=os.path.join(
                        cfg.AUTH_KEYS_DIR, f"{request.participant_id}.key"
                    ),
                    auth_pub_path=os.path.join(
                        cfg.AUTH_KEYS_DIR, f"{request.participant_id}.pub"
                    ),
                    superlink_address=superlink_address,
                    partition_id=request.partition_id,
                    num_partitions=int(num_partitions),
                )
            except FileNotFoundError:
                # Generation failed: nothing will be served from the directory,
                # so remove it instead of leaking it.
                shutil.rmtree(blossomfile_tempdir, ignore_errors=True)
                return (False, "An error occurred.", None)
            except Exception:
                shutil.rmtree(blossomfile_tempdir, ignore_errors=True)
                raise

            connection_string = settings.get_text(
                "status_approved_md",
                participant_id=request.participant_id,
                partition_id=request.partition_id,
                # Pass host and port directly rather than re-splitting the
                # address on ":", which breaks for hosts containing colons.
                superlink_hostname=superlink_host,
                superlink_port=str(superlink_port),
                num_partitions=num_partitions,
            )
            return (True, connection_string, blossomfile_path)

        if request.status == "pending":
            return (False, settings.get_text("status_pending_md"), None)

        # Any other status is treated as denied.
        return (
            False,
            settings.get_text(
                "status_denied_md", participant_id=request.participant_id
            ),
            None,
        )


def _parse_partition_id(partition_id):
    """Return ``partition_id`` as a non-negative int, or ``None`` if invalid.

    Accepts an ``int`` (not ``bool``) or a string of decimal digits, ignoring
    surrounding whitespace. ``str.isdecimal`` is used rather than
    ``str.isdigit`` because the latter also accepts characters such as
    superscript digits, which ``int()`` rejects with ``ValueError``.
    """
    if isinstance(partition_id, bool):
        return None
    if isinstance(partition_id, int):
        return partition_id if partition_id >= 0 else None
    if isinstance(partition_id, str):
        text = partition_id.strip()
        if text.isdecimal():
            return int(text)
    return None


def _rebuild_authorized_keys(db):
    """Rewrite the authorized-keys CSV from all approved participants that have a public key."""
    approved_participants = (
        db.query(Request.participant_id, Request.public_key_pem)
        .filter(Request.status == "approved", Request.public_key_pem.isnot(None))
        .all()
    )
    rebuild_authorized_keys_csv(cfg.AUTH_KEYS_DIR, approved_participants)


def manage_request(participant_id: str, partition_id: str, action: str):
    """Admin function to approve/deny a request and assign a partition ID.

    Approving assigns ``partition_id``, generates the participant's key pair
    with ``AuthKeyGenerator`` under ``cfg.AUTH_KEYS_DIR``, stores the public
    key, and rebuilds the authorized-keys CSV. Any other ``action`` denies the
    request, clears its partition and public key, and rebuilds the CSV so the
    participant's key is no longer listed.

    Args:
        participant_id: ID of the participant to act on.
        partition_id: Partition to assign on approval, as a non-negative
            decimal string (an ``int`` is also accepted).
        action: ``"approve"``; anything else is treated as a denial.

    Returns:
        A tuple ``(success, message)``.
    """
    if not participant_id:
        return False, "Please select a participant from the pending requests table."

    with SessionLocal() as db:
        request = (
            db.query(Request).filter(Request.participant_id == participant_id).first()
        )
        if not request:
            return False, "Participant not found."

        if action == "approve":
            p_id_int = _parse_partition_id(partition_id)
            if p_id_int is None:
                return False, "Please provide a valid integer for the Partition ID."
            if not request.is_activated:
                return (
                    False,
                    settings.get_text("participant_not_activated_warning_md"),
                )

            # Each partition may belong to at most one approved participant.
            existing_participant = (
                db.query(Request)
                .filter(Request.partition_id == p_id_int, Request.status == "approved")
                .first()
            )
            if existing_participant:
                return (
                    False,
                    settings.get_text(
                        "partition_in_use_warning_md", partition_id=p_id_int
                    ),
                )

            request.status = "approved"
            request.partition_id = p_id_int

            # Keys are generated before the commit; if generation raises, the
            # status change above is never committed (assuming the session
            # discards uncommitted changes on close — confirm SessionLocal config).
            key_generator = AuthKeyGenerator(key_dir=cfg.AUTH_KEYS_DIR)
            _, _, public_key_pem = key_generator.generate_participant_keys(
                participant_id
            )
            request.public_key_pem = public_key_pem
            db.commit()

            _rebuild_authorized_keys(db)

            return (
                True,
                f"Participant {participant_id} approved. Keys generated and registry updated.",
            )

        # Any action other than "approve" is treated as a denial.
        request.status = "denied"
        request.partition_id = None
        request.public_key_pem = None
        db.commit()

        # Rebuild the CSV so the denied participant's key is revoked.
        _rebuild_authorized_keys(db)

        return (
            True,
            f"Participant {participant_id} denied. Their access has been revoked.",
        )


def get_next_partion_id() -> int:
    """Return the smallest non-negative partition ID not held by an approved participant.

    (The misspelled name is kept so existing callers keep working.)
    """
    with SessionLocal() as db:
        rows = (
            db.query(Request.partition_id)
            .filter(Request.status == "approved", Request.partition_id.isnot(None))
            .all()
        )
    taken = {row[0] for row in rows}

    # With k IDs taken, at least one of 0..k is free (pigeonhole), so the
    # lowest free ID is the minimum of that range minus the taken set.
    candidates = set(range(len(taken) + 1))
    return min(candidates - taken)