mrs83 commited on
Commit
e60e679
·
1 Parent(s): 4a4c0cf

add more tests, remove warning

Browse files
blossomtune_gradio/util.py CHANGED
@@ -22,7 +22,7 @@ def validate_email(email_address: str) -> bool:
22
  # DNS MX record validation
23
  try:
24
  domain = email_address.rsplit("@", 1)[-1]
25
- dns.resolver.query(domain, "MX")
26
  return True
27
  except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, IndexError):
28
  return False
 
22
  # DNS MX record validation
23
  try:
24
  domain = email_address.rsplit("@", 1)[-1]
25
+ dns.resolver.resolve(domain, "MX")
26
  return True
27
  except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, IndexError):
28
  return False
tests/test_federation.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import pytest
3
+ from unittest.mock import MagicMock
4
+
5
+ from blossomtune_gradio import federation as fed
6
+ from blossomtune_gradio import database
7
+
8
+
9
+ @pytest.fixture
10
+ def in_memory_db(mocker):
11
+ """
12
+ Fixture to set up and tear down an in-memory SQLite database for tests.
13
+ It ensures that the same connection object is used for both schema
14
+ initialization and the test execution.
15
+ """
16
+ con = sqlite3.connect(":memory:")
17
+ mocker.patch("sqlite3.connect", return_value=con)
18
+ database.init()
19
+ yield con
20
+ con.close()
21
+
22
+
23
+ @pytest.fixture
24
+ def mock_settings(mocker):
25
+ """Fixture to mock the settings module."""
26
+ # The lambda returns a formatted string to simulate Jinja2's behavior
27
+ mock_get = MagicMock(
28
+ side_effect=lambda key, **kwargs: f"mock_{key}".format(**kwargs)
29
+ )
30
+ mocker.patch("blossomtune_gradio.federation.settings.get_text", mock_get)
31
+ return mock_get
32
+
33
+
34
+ @pytest.fixture
35
+ def mock_mail(mocker):
36
+ """Fixture to mock the mail module."""
37
+ return mocker.patch("blossomtune_gradio.mail.send_activation_email")
38
+
39
+
40
+ def test_generate_participant_id():
41
+ """Test the generation of a participant ID."""
42
+ pid = fed.generate_participant_id()
43
+ assert isinstance(pid, str)
44
+ assert len(pid) == 6
45
+ assert pid.isalnum() and pid.isupper()
46
+
47
+
48
+ def test_generate_activation_code():
49
+ """Test the generation of an activation code."""
50
+ code = fed.generate_activation_code()
51
+ assert isinstance(code, str)
52
+ assert len(code) == 8
53
+ assert code.isalnum() and code.isupper()
54
+
55
+
56
+ class TestCheckParticipantStatus:
57
+ """Test suite for the check_participant_status function."""
58
+
59
+ def test_new_user_registration_success(
60
+ self, in_memory_db, mock_settings, mock_mail
61
+ ):
62
+ """Verify successful registration for a new user."""
63
+ mock_mail.return_value = (True, "")
64
+ success, message = fed.check_participant_status(
65
+ "new_user", "new@example.com", ""
66
+ )
67
+ assert success is True
68
+ assert message == "mock_registration_submitted_md"
69
+
70
+ # Verify the user was added to the database
71
+ cursor = in_memory_db.cursor()
72
+ cursor.execute("SELECT hf_handle FROM requests WHERE hf_handle = 'new_user'")
73
+ assert cursor.fetchone() is not None
74
+
75
+ def test_new_user_invalid_email(self, in_memory_db, mock_settings):
76
+ """Verify registration fails with an invalid email."""
77
+ success, message = fed.check_participant_status("user", "invalid-email", "")
78
+ assert success is False
79
+ assert message == "mock_invalid_email_md"
80
+
81
+ def test_new_user_federation_full(self, in_memory_db, mock_settings, mocker):
82
+ """Verify registration fails when the federation is full."""
83
+ mocker.patch("blossomtune_gradio.federation.cfg.MAX_NUM_NODES", 0)
84
+ success, message = fed.check_participant_status(
85
+ "another_user", "another@example.com", ""
86
+ )
87
+ assert success is False
88
+ assert message == "mock_federation_full_md"
89
+
90
+ def test_user_activation_success(self, in_memory_db, mock_settings):
91
+ """Verify a user can successfully activate their account."""
92
+ # Setup: Add a pending, non-activated user
93
+ cursor = in_memory_db.cursor()
94
+ cursor.execute(
95
+ "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
96
+ ("PID123", "pending", "now", "test_user", "test@example.com", "ABCDEF", 0),
97
+ )
98
+ in_memory_db.commit()
99
+
100
+ success, message = fed.check_participant_status(
101
+ "test_user", "test@example.com", "ABCDEF"
102
+ )
103
+ assert success is True
104
+ assert message == "mock_activation_successful_md"
105
+ # Verify the user is now activated
106
+ cursor.execute(
107
+ "SELECT is_activated FROM requests WHERE hf_handle = 'test_user'"
108
+ )
109
+ assert cursor.fetchone()[0] == 1
110
+
111
+ def test_user_activation_invalid_code(self, in_memory_db, mock_settings):
112
+ """Verify activation fails with an invalid code."""
113
+ # Setup: Add a pending, non-activated user
114
+ cursor = in_memory_db.cursor()
115
+ cursor.execute(
116
+ "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
117
+ ("PID123", "pending", "now", "test_user", "test@example.com", "ABCDEF", 0),
118
+ )
119
+ in_memory_db.commit()
120
+
121
+ success, message = fed.check_participant_status(
122
+ "test_user", "test@example.com", "WRONGCODE"
123
+ )
124
+ assert success is False
125
+ assert message == "mock_activation_invalid_md"
126
+
127
+ def test_status_check_approved(self, in_memory_db, mock_settings):
128
+ """Verify the status check for an approved user."""
129
+ cursor = in_memory_db.cursor()
130
+ cursor.execute(
131
+ "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated, partition_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
132
+ (
133
+ "PID456",
134
+ "approved",
135
+ "now",
136
+ "approved_user",
137
+ "approved@example.com",
138
+ "GHIJKL",
139
+ 1,
140
+ 5,
141
+ ),
142
+ )
143
+ in_memory_db.commit()
144
+ success, message = fed.check_participant_status(
145
+ "approved_user", "approved@example.com", "GHIJKL"
146
+ )
147
+ assert success is True
148
+ assert "mock_status_approved_md" in message
149
+
150
+
151
+ class TestManageRequest:
152
+ """Test suite for the manage_request function."""
153
+
154
+ def test_approve_success(self, in_memory_db):
155
+ """Verify successful approval of a participant."""
156
+ # Setup: Add an activated, pending user
157
+ cursor = in_memory_db.cursor()
158
+ cursor.execute(
159
+ "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
160
+ (
161
+ "PENDING1",
162
+ "pending",
163
+ "now",
164
+ "pending_user",
165
+ "pending@example.com",
166
+ "CODE",
167
+ 1,
168
+ ),
169
+ )
170
+ in_memory_db.commit()
171
+
172
+ success, message = fed.manage_request("PENDING1", "10", "approve")
173
+ assert success is True
174
+ assert "is allowed to join" in message
175
+
176
+ # Verify status in DB
177
+ cursor.execute(
178
+ "SELECT status, partition_id FROM requests WHERE participant_id = 'PENDING1'"
179
+ )
180
+ status, partition_id = cursor.fetchone()
181
+ assert status == "approved"
182
+ assert partition_id == 10
183
+
184
+ def test_approve_not_activated(self, in_memory_db, mock_settings):
185
+ """Verify approval fails if the user is not activated."""
186
+ cursor = in_memory_db.cursor()
187
+ cursor.execute(
188
+ "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
189
+ (
190
+ "PENDING2",
191
+ "pending",
192
+ "now",
193
+ "pending_user2",
194
+ "pending2@example.com",
195
+ "CODE",
196
+ 0,
197
+ ),
198
+ )
199
+ in_memory_db.commit()
200
+ success, message = fed.manage_request("PENDING2", "11", "approve")
201
+ assert success is False
202
+ assert message == "mock_participant_not_activated_warning_md"
203
+
204
+ def test_deny_success(self, in_memory_db):
205
+ """Verify successful denial of a participant."""
206
+ cursor = in_memory_db.cursor()
207
+ cursor.execute(
208
+ "INSERT INTO requests (participant_id, status, timestamp, hf_handle, email, activation_code, is_activated) VALUES (?, ?, ?, ?, ?, ?, ?)",
209
+ (
210
+ "PENDING3",
211
+ "pending",
212
+ "now",
213
+ "pending_user3",
214
+ "pending3@example.com",
215
+ "CODE",
216
+ 1,
217
+ ),
218
+ )
219
+ in_memory_db.commit()
220
+ success, message = fed.manage_request("PENDING3", "", "deny")
221
+ assert success is True
222
+ assert "is not allowed to join" in message
223
+
224
+ # Verify status in DB
225
+ cursor.execute("SELECT status FROM requests WHERE participant_id = 'PENDING3'")
226
+ assert cursor.fetchone()[0] == "denied"
227
+
228
+
229
+ def test_get_next_partition_id(in_memory_db):
230
+ """Verify the logic for finding the next available partition ID."""
231
+ cursor = in_memory_db.cursor()
232
+ # No approved users yet
233
+ assert fed.get_next_partion_id() == 0
234
+
235
+ # Add some approved users with assigned partitions, including the required timestamp
236
+ cursor.execute(
237
+ "INSERT INTO requests (participant_id, status, timestamp, partition_id) VALUES (?, ?, ?, ?)",
238
+ ("P1", "approved", "now", 0),
239
+ )
240
+ cursor.execute(
241
+ "INSERT INTO requests (participant_id, status, timestamp, partition_id) VALUES (?, ?, ?, ?)",
242
+ ("P2", "approved", "now", 1),
243
+ )
244
+ in_memory_db.commit()
245
+ assert fed.get_next_partion_id() == 2
246
+
247
+ # Add another user, skipping a partition ID
248
+ cursor.execute(
249
+ "INSERT INTO requests (participant_id, status, timestamp, partition_id) VALUES (?, ?, ?, ?)",
250
+ ("P3", "approved", "now", 3),
251
+ )
252
+ in_memory_db.commit()
253
+ assert fed.get_next_partion_id() == 2
tests/test_util.py CHANGED
@@ -9,7 +9,7 @@ def test_validate_email_valid(monkeypatch):
9
  """Tests a syntactically valid email with an existing MX record."""
10
  # Mock the dns.resolver.query to return a successful result.
11
  mock_query = MagicMock()
12
- monkeypatch.setattr(dns.resolver, "query", mock_query)
13
 
14
  email = "test@google.com"
15
  assert validate_email(email) is True
@@ -28,7 +28,7 @@ def test_validate_email_no_mx_record(monkeypatch):
28
  """Tests a domain that exists but has no MX record."""
29
  # Mock the dns.resolver.query to raise a NoAnswer exception.
30
  mock_query = MagicMock(side_effect=dns.resolver.NoAnswer)
31
- monkeypatch.setattr(dns.resolver, "query", mock_query)
32
 
33
  email = "user@example.com"
34
  assert validate_email(email) is False
@@ -39,7 +39,7 @@ def test_validate_email_non_existent_domain(monkeypatch):
39
  """Tests a domain that does not exist."""
40
  # Mock the dns.resolver.query to raise an NXDOMAIN exception.
41
  mock_query = MagicMock(side_effect=dns.resolver.NXDOMAIN)
42
- monkeypatch.setattr(dns.resolver, "query", mock_query)
43
 
44
  email = "user@not-a-real-domain-123.com"
45
  assert validate_email(email) is False
 
9
  """Tests a syntactically valid email with an existing MX record."""
10
  # Mock the dns.resolver.query to return a successful result.
11
  mock_query = MagicMock()
12
+ monkeypatch.setattr(dns.resolver, "resolve", mock_query)
13
 
14
  email = "test@google.com"
15
  assert validate_email(email) is True
 
28
  """Tests a domain that exists but has no MX record."""
29
  # Mock the dns.resolver.query to raise a NoAnswer exception.
30
  mock_query = MagicMock(side_effect=dns.resolver.NoAnswer)
31
+ monkeypatch.setattr(dns.resolver, "resolve", mock_query)
32
 
33
  email = "user@example.com"
34
  assert validate_email(email) is False
 
39
  """Tests a domain that does not exist."""
40
  # Mock the dns.resolver.query to raise an NXDOMAIN exception.
41
  mock_query = MagicMock(side_effect=dns.resolver.NXDOMAIN)
42
+ monkeypatch.setattr(dns.resolver, "resolve", mock_query)
43
 
44
  email = "user@not-a-real-domain-123.com"
45
  assert validate_email(email) is False