Expand colab notebook
- examples/pysr_demo.ipynb +71 -8
examples/pysr_demo.ipynb
CHANGED
|
@@ -1262,6 +1262,7 @@
|
|
| 1262 |
]
|
| 1263 |
},
|
| 1264 |
{
|
| 1265 |
"cell_type": "markdown",
|
| 1266 |
"metadata": {
|
| 1267 |
"id": "nCCIvvAGuyFi"
|
|
@@ -1269,7 +1270,60 @@
|
|
| 1269 |
"source": [
|
| 1270 |
"## Learning over the network:\n",
|
| 1271 |
"\n",
|
| 1272 |
-
"Now, let's fit `g` using PySR
|
| 1273 |
]
|
| 1274 |
},
|
| 1275 |
{
|
|
@@ -1281,17 +1335,15 @@
|
|
| 1281 |
},
|
| 1282 |
"outputs": [],
|
| 1283 |
"source": [
|
| 1284 |
-
"np.random.
|
| 1285 |
-
"
|
| 1286 |
-
"tmpy = y_i_for_pysr.detach().numpy().reshape(-1)\n",
|
| 1287 |
-
"idx2 = np.random.randint(0, tmpy.shape[0], size=500)\n",
|
| 1288 |
"\n",
|
| 1289 |
"model = PySRRegressor(\n",
|
| 1290 |
" niterations=20,\n",
|
| 1291 |
" binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
|
| 1292 |
" unary_operators=[\"cos\", \"square\", \"neg\"],\n",
|
| 1293 |
")\n",
|
| 1294 |
-
"model.fit(
|
| 1295 |
]
|
| 1296 |
},
|
| 1297 |
{
|
|
@@ -1310,9 +1362,12 @@
|
|
| 1310 |
"id": "6WuaeqyqbDhe"
|
| 1311 |
},
|
| 1312 |
"source": [
|
| 1313 |
-
"Recall we are searching for $
|
| 1314 |
"\n",
|
| 1315 |
-
"
|
| 1316 |
]
|
| 1317 |
},
|
| 1318 |
{
|
|
@@ -1384,7 +1439,15 @@
|
|
| 1384 |
"name": "main_ipynb"
|
| 1385 |
},
|
| 1386 |
"language_info": {
|
| 1387 |
"name": "python",
|
| 1388 |
"version": "3.10.9"
|
| 1389 |
}
|
| 1390 |
},
|
|
|
|
| 1262 |
]
|
| 1263 |
},
|
| 1264 |
{
|
| 1265 |
+
"attachments": {},
|
| 1266 |
"cell_type": "markdown",
|
| 1267 |
"metadata": {
|
| 1268 |
"id": "nCCIvvAGuyFi"
|
|
|
|
| 1270 |
"source": [
|
| 1271 |
"## Learning over the network:\n",
|
| 1272 |
"\n",
|
| 1273 |
+
"Now, let's fit `g` using PySR.\n",
|
| 1274 |
+
"\n",
|
| 1275 |
+
"> **Warning**\n",
|
| 1276 |
+
">\n",
|
| 1277 |
+
"> First, let's save the data, because sometimes PyTorch and PyJulia's C bindings interfere and cause the colab kernel to crash. If we need to restart, we can just load the data without having to retrain the network:"
|
| 1278 |
+
]
|
| 1279 |
+
},
|
| 1280 |
+
{
|
| 1281 |
+
"cell_type": "code",
|
| 1282 |
+
"execution_count": null,
|
| 1283 |
+
"metadata": {},
|
| 1284 |
+
"outputs": [],
|
| 1285 |
+
"source": [
|
| 1286 |
+
"nnet_recordings = {\n",
|
| 1287 |
+
" \"g_input\": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),\n",
|
| 1288 |
+
" \"g_output\": y_i_for_pysr.detach().cpu().numpy().reshape(-1),\n",
|
| 1289 |
+
" \"f_input\": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),\n",
|
| 1290 |
+
" \"f_output\": z_for_pysr.detach().cpu().numpy().reshape(-1),\n",
|
| 1291 |
+
"}\n",
|
| 1292 |
+
"\n",
|
| 1293 |
+
"# Save the data for later use:\n",
|
| 1294 |
+
"import pickle as pkl\n",
|
| 1295 |
+
"\n",
|
| 1296 |
+
"with open(\"nnet_recordings.pkl\", \"wb\") as f:\n",
|
| 1297 |
+
" pkl.dump(nnet_recordings, f)"
|
| 1298 |
+
]
|
| 1299 |
+
},
|
| 1300 |
+
{
|
| 1301 |
+
"attachments": {},
|
| 1302 |
+
"cell_type": "markdown",
|
| 1303 |
+
"metadata": {},
|
| 1304 |
+
"source": [
|
| 1305 |
+
"We can now load the data:"
|
| 1306 |
+
]
|
| 1307 |
+
},
|
| 1308 |
+
{
|
| 1309 |
+
"cell_type": "code",
|
| 1310 |
+
"execution_count": null,
|
| 1311 |
+
"metadata": {},
|
| 1312 |
+
"outputs": [],
|
| 1313 |
+
"source": [
|
| 1314 |
+
"nnet_recordings = pkl.load(open(\"nnet_recordings.pkl\", \"rb\"))\n",
|
| 1315 |
+
"f_input = nnet_recordings[\"f_input\"]\n",
|
| 1316 |
+
"f_output = nnet_recordings[\"f_output\"]\n",
|
| 1317 |
+
"g_input = nnet_recordings[\"g_input\"]\n",
|
| 1318 |
+
"g_output = nnet_recordings[\"g_output\"]"
|
| 1319 |
+
]
|
| 1320 |
+
},
|
| 1321 |
+
{
|
| 1322 |
+
"attachments": {},
|
| 1323 |
+
"cell_type": "markdown",
|
| 1324 |
+
"metadata": {},
|
| 1325 |
+
"source": [
|
| 1326 |
+
"And now fit using a subsample of the data (symbolic regression only needs a small sample to find the best equation):"
|
| 1327 |
]
|
| 1328 |
},
|
| 1329 |
{
|
|
|
|
| 1335 |
},
|
| 1336 |
"outputs": [],
|
| 1337 |
"source": [
|
| 1338 |
+
"rstate = np.random.RandomState(0)\n",
|
| 1339 |
+
"f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)\n",
|
| 1340 |
"\n",
|
| 1341 |
"model = PySRRegressor(\n",
|
| 1342 |
" niterations=20,\n",
|
| 1343 |
" binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
|
| 1344 |
" unary_operators=[\"cos\", \"square\", \"neg\"],\n",
|
| 1345 |
")\n",
|
| 1346 |
+
"model.fit(g_input[f_sample_idx], g_output[f_sample_idx])"
|
| 1347 |
]
|
| 1348 |
},
|
| 1349 |
{
|
|
|
|
| 1362 |
"id": "6WuaeqyqbDhe"
|
| 1363 |
},
|
| 1364 |
"source": [
|
| 1365 |
+
"Recall we are searching for $f$ and $g$ such that:\n",
|
| 1366 |
+
"$$z=f(\\sum g(x_i))$$ \n",
|
| 1367 |
+
"which approximates the true relation:\n",
|
| 1368 |
+
"$$ z = y^2,\\quad y = \\frac{1}{10} \\sum(y_i),\\quad y_i = x_{i0}^2 + 6 \\cos(2 x_{i2})$$\n",
|
| 1369 |
"\n",
|
| 1370 |
+
"Let's see how well we did in recovering $g$:"
|
| 1371 |
]
|
| 1372 |
},
|
| 1373 |
{
|
|
|
|
| 1439 |
"name": "main_ipynb"
|
| 1440 |
},
|
| 1441 |
"language_info": {
|
| 1442 |
+
"codemirror_mode": {
|
| 1443 |
+
"name": "ipython",
|
| 1444 |
+
"version": 3
|
| 1445 |
+
},
|
| 1446 |
+
"file_extension": ".py",
|
| 1447 |
+
"mimetype": "text/x-python",
|
| 1448 |
"name": "python",
|
| 1449 |
+
"nbconvert_exporter": "python",
|
| 1450 |
+
"pygments_lexer": "ipython3",
|
| 1451 |
"version": "3.10.9"
|
| 1452 |
}
|
| 1453 |
},
|
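The save, reload, and subsample steps added in this change can be tried without PyTorch or PySR. The following is a minimal sketch under stated assumptions: the arrays are random stand-ins for the recorded network activations, the shapes (100 examples, 10 elements per example, 5 features) are illustrative guesses, and the temporary file path is hypothetical. Unlike the diff, which draws the sample index from `f_input.shape[0]`, this sketch draws it from `g_input.shape[0]`, so that every row of `g_input` can be selected.

```python
import os
import pickle
import tempfile

import numpy as np

# Random stand-ins for the recorded activations (shapes are assumptions).
rng = np.random.default_rng(0)
n_examples, n_elements, n_features = 100, 10, 5
recordings = {
    "g_input": rng.normal(size=(n_examples * n_elements, n_features)),
    "g_output": rng.normal(size=n_examples * n_elements),
    "f_input": rng.normal(size=(n_examples, 1)),
    "f_output": rng.normal(size=n_examples),
}

# Save to a temporary location (hypothetical path, chosen for this sketch).
path = os.path.join(tempfile.mkdtemp(), "nnet_recordings.pkl")
with open(path, "wb") as f:
    pickle.dump(recordings, f)

# Reload with a context manager so the file handle is closed afterwards.
with open(path, "rb") as f:
    loaded = pickle.load(f)

g_input, g_output = loaded["g_input"], loaded["g_output"]

# Subsample 500 distinct rows, indexing over g's own row count.
g_sample_idx = rng.choice(g_input.shape[0], size=500, replace=False)
g_X, g_y = g_input[g_sample_idx], g_output[g_sample_idx]
```

The resulting `g_X` and `g_y` are what would then be passed to `model.fit` in place of the full arrays.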