Mark-Lasfar
commited on
Commit
·
a7d3cc7
1
Parent(s):
5e6c2a1
Fix ChunkedIteratorResult in SQLAlchemyUserDatabase and toggleBtn null error
Browse files- api/auth.py +36 -78
api/auth.py
CHANGED
|
@@ -16,7 +16,7 @@ import os
|
|
| 16 |
import logging
|
| 17 |
import secrets
|
| 18 |
|
| 19 |
-
#
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
cookie_transport = CookieTransport(cookie_max_age=3600)
|
|
@@ -35,23 +35,17 @@ auth_backend = AuthenticationBackend(
|
|
| 35 |
get_strategy=get_jwt_strategy,
|
| 36 |
)
|
| 37 |
|
| 38 |
-
# OAuth
|
| 39 |
GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
|
| 40 |
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
|
| 41 |
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
|
| 42 |
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
logger.info("GOOGLE_CLIENT_ID is set: %s", bool(GOOGLE_CLIENT_ID))
|
| 46 |
-
logger.info("GOOGLE_CLIENT_SECRET is set: %s", bool(GOOGLE_CLIENT_SECRET))
|
| 47 |
-
logger.info("GITHUB_CLIENT_ID is set: %s", bool(GITHUB_CLIENT_ID))
|
| 48 |
-
logger.info("GITHUB_CLIENT_SECRET is set: %s", bool(GITHUB_CLIENT_SECRET))
|
| 49 |
-
|
| 50 |
if not all([GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET, GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET]):
|
| 51 |
logger.error("One or more OAuth environment variables are missing.")
|
| 52 |
-
raise ValueError("All OAuth credentials
|
| 53 |
|
| 54 |
-
# OAuth redirect URLs
|
| 55 |
GOOGLE_REDIRECT_URL = os.getenv("GOOGLE_REDIRECT_URL", "https://mgzon-mgzon-app.hf.space/auth/google/callback")
|
| 56 |
GITHUB_REDIRECT_URL = os.getenv("GITHUB_REDIRECT_URL", "https://mgzon-mgzon-app.hf.space/auth/github/callback")
|
| 57 |
|
|
@@ -60,70 +54,37 @@ github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
|
|
| 60 |
|
| 61 |
class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
|
| 62 |
async def get_by_email(self, email: str) -> Optional[User]:
|
| 63 |
-
"""Override to fix ChunkedIteratorResult issue."""
|
| 64 |
logger.info(f"Checking for user with email: {email}")
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
user = result.scalar_one_or_none()
|
| 69 |
-
if user:
|
| 70 |
-
logger.info(f"Found user with email: {email}")
|
| 71 |
-
else:
|
| 72 |
-
logger.info(f"No user found with email: {email}")
|
| 73 |
-
return user
|
| 74 |
-
except Exception as e:
|
| 75 |
-
logger.error(f"Error in get_by_email: {e}")
|
| 76 |
-
raise
|
| 77 |
|
| 78 |
async def create(self, create_dict: Dict[str, Any]) -> User:
|
| 79 |
-
"""Override to fix potential async issues."""
|
| 80 |
logger.info(f"Creating user with email: {create_dict.get('email')}")
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
logger.info(f"Created user with email: {create_dict.get('email')}")
|
| 87 |
-
return user
|
| 88 |
-
except Exception as e:
|
| 89 |
-
logger.error(f"Error creating user: {e}")
|
| 90 |
-
await self.session.rollback()
|
| 91 |
-
raise
|
| 92 |
|
| 93 |
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
| 94 |
reset_password_token_secret = SECRET
|
| 95 |
verification_token_secret = SECRET
|
| 96 |
|
| 97 |
async def get_by_oauth_account(self, oauth_name: str, account_id: str):
|
| 98 |
-
"
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
oauth_account = result.scalar_one_or_none()
|
| 106 |
-
if oauth_account:
|
| 107 |
-
logger.info(f"Found existing OAuth account for {account_id}")
|
| 108 |
-
else:
|
| 109 |
-
logger.info(f"No existing OAuth account found for {account_id}")
|
| 110 |
-
return oauth_account
|
| 111 |
-
except Exception as e:
|
| 112 |
-
logger.error(f"Error in get_by_oauth_account: {e}")
|
| 113 |
-
raise
|
| 114 |
|
| 115 |
async def add_oauth_account(self, oauth_account: OAuthAccount):
|
| 116 |
-
"""Override to fix potential async issues."""
|
| 117 |
logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
await self.session.refresh(oauth_account)
|
| 122 |
-
logger.info(f"Successfully added OAuth account for user {oauth_account.user_id}")
|
| 123 |
-
except Exception as e:
|
| 124 |
-
logger.error(f"Error adding OAuth account: {e}")
|
| 125 |
-
await self.session.rollback()
|
| 126 |
-
raise
|
| 127 |
|
| 128 |
async def oauth_callback(
|
| 129 |
self,
|
|
@@ -138,27 +99,25 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
|
| 138 |
associate_by_email: bool = False,
|
| 139 |
is_verified_by_default: bool = False,
|
| 140 |
) -> UP:
|
| 141 |
-
logger.info(f"OAuth callback for {oauth_name}
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
existing_oauth_account = await self.get_by_oauth_account(oauth_name, account_id)
|
| 152 |
-
if existing_oauth_account
|
| 153 |
-
logger.info(f"Existing account found, logging in user {existing_oauth_account.user.email}")
|
| 154 |
return await self.on_after_login(existing_oauth_account.user, request)
|
| 155 |
|
| 156 |
if associate_by_email:
|
| 157 |
user = await self.user_db.get_by_email(account_email)
|
| 158 |
-
if user
|
| 159 |
oauth_account.user_id = user.id
|
| 160 |
await self.add_oauth_account(oauth_account)
|
| 161 |
-
logger.info(f"Associated with existing user {user.email}")
|
| 162 |
return await self.on_after_login(user, request)
|
| 163 |
|
| 164 |
user_dict = {
|
|
@@ -170,9 +129,9 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
|
| 170 |
user = await self.user_db.create(user_dict)
|
| 171 |
oauth_account.user_id = user.id
|
| 172 |
await self.add_oauth_account(oauth_account)
|
| 173 |
-
logger.info(f"Created new user {user.email}")
|
| 174 |
return await self.on_after_login(user, request)
|
| 175 |
|
|
|
|
| 176 |
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
| 177 |
yield UserManager(user_db)
|
| 178 |
|
|
@@ -184,7 +143,6 @@ google_oauth_router = get_oauth_router(
|
|
| 184 |
associate_by_email=True,
|
| 185 |
redirect_url=GOOGLE_REDIRECT_URL,
|
| 186 |
)
|
| 187 |
-
logger.info("Google OAuth router initialized successfully")
|
| 188 |
|
| 189 |
github_oauth_router = get_oauth_router(
|
| 190 |
github_oauth_client,
|
|
@@ -194,7 +152,6 @@ github_oauth_router = get_oauth_router(
|
|
| 194 |
associate_by_email=True,
|
| 195 |
redirect_url=GITHUB_REDIRECT_URL,
|
| 196 |
)
|
| 197 |
-
logger.info("GitHub OAuth router initialized successfully")
|
| 198 |
|
| 199 |
fastapi_users = FastAPIUsers[User, int](
|
| 200 |
get_user_db,
|
|
@@ -211,3 +168,4 @@ def get_auth_router(app: FastAPI):
|
|
| 211 |
app.include_router(fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"])
|
| 212 |
app.include_router(fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"])
|
| 213 |
app.include_router(fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", tags=["users"])
|
|
|
|
|
|
| 16 |
import logging
|
| 17 |
import secrets
|
| 18 |
|
| 19 |
+
# إعداد اللوقينج
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
cookie_transport = CookieTransport(cookie_max_age=3600)
|
|
|
|
| 35 |
get_strategy=get_jwt_strategy,
|
| 36 |
)
|
| 37 |
|
| 38 |
+
# OAuth بيانات الاعتماد
|
| 39 |
GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
|
| 40 |
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
|
| 41 |
GITHUB_CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
|
| 42 |
GITHUB_CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
|
| 43 |
|
| 44 |
+
# تحقق من توافر بيانات الاعتماد
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
if not all([GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET, GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET]):
|
| 46 |
logger.error("One or more OAuth environment variables are missing.")
|
| 47 |
+
raise ValueError("All OAuth credentials are required.")
|
| 48 |
|
|
|
|
| 49 |
GOOGLE_REDIRECT_URL = os.getenv("GOOGLE_REDIRECT_URL", "https://mgzon-mgzon-app.hf.space/auth/google/callback")
|
| 50 |
GITHUB_REDIRECT_URL = os.getenv("GITHUB_REDIRECT_URL", "https://mgzon-mgzon-app.hf.space/auth/github/callback")
|
| 51 |
|
|
|
|
| 54 |
|
| 55 |
class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
|
| 56 |
async def get_by_email(self, email: str) -> Optional[User]:
|
|
|
|
| 57 |
logger.info(f"Checking for user with email: {email}")
|
| 58 |
+
statement = select(self.user_table).where(self.user_table.email == email)
|
| 59 |
+
result = await self.session.execute(statement)
|
| 60 |
+
return result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
async def create(self, create_dict: Dict[str, Any]) -> User:
|
|
|
|
| 63 |
logger.info(f"Creating user with email: {create_dict.get('email')}")
|
| 64 |
+
user = self.user_table(**create_dict)
|
| 65 |
+
self.session.add(user)
|
| 66 |
+
await self.session.commit()
|
| 67 |
+
await self.session.refresh(user)
|
| 68 |
+
return user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
| 71 |
reset_password_token_secret = SECRET
|
| 72 |
verification_token_secret = SECRET
|
| 73 |
|
| 74 |
async def get_by_oauth_account(self, oauth_name: str, account_id: str):
|
| 75 |
+
logger.info(f"Checking OAuth account: {oauth_name}/{account_id}")
|
| 76 |
+
statement = select(OAuthAccount).where(
|
| 77 |
+
OAuthAccount.oauth_name == oauth_name,
|
| 78 |
+
OAuthAccount.account_id == account_id
|
| 79 |
+
)
|
| 80 |
+
result = await self.user_db.session.execute(statement)
|
| 81 |
+
return result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
async def add_oauth_account(self, oauth_account: OAuthAccount):
|
|
|
|
| 84 |
logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
|
| 85 |
+
self.user_db.session.add(oauth_account)
|
| 86 |
+
await self.user_db.session.commit()
|
| 87 |
+
await self.user_db.session.refresh(oauth_account)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
async def oauth_callback(
|
| 90 |
self,
|
|
|
|
| 99 |
associate_by_email: bool = False,
|
| 100 |
is_verified_by_default: bool = False,
|
| 101 |
) -> UP:
|
| 102 |
+
logger.info(f"OAuth callback for {oauth_name} account {account_id}")
|
| 103 |
+
|
| 104 |
+
oauth_account = OAuthAccount(
|
| 105 |
+
oauth_name=oauth_name,
|
| 106 |
+
access_token=access_token,
|
| 107 |
+
account_id=account_id,
|
| 108 |
+
account_email=account_email,
|
| 109 |
+
expires_at=expires_at,
|
| 110 |
+
refresh_token=refresh_token,
|
| 111 |
+
)
|
| 112 |
existing_oauth_account = await self.get_by_oauth_account(oauth_name, account_id)
|
| 113 |
+
if existing_oauth_account:
|
|
|
|
| 114 |
return await self.on_after_login(existing_oauth_account.user, request)
|
| 115 |
|
| 116 |
if associate_by_email:
|
| 117 |
user = await self.user_db.get_by_email(account_email)
|
| 118 |
+
if user:
|
| 119 |
oauth_account.user_id = user.id
|
| 120 |
await self.add_oauth_account(oauth_account)
|
|
|
|
| 121 |
return await self.on_after_login(user, request)
|
| 122 |
|
| 123 |
user_dict = {
|
|
|
|
| 129 |
user = await self.user_db.create(user_dict)
|
| 130 |
oauth_account.user_id = user.id
|
| 131 |
await self.add_oauth_account(oauth_account)
|
|
|
|
| 132 |
return await self.on_after_login(user, request)
|
| 133 |
|
| 134 |
+
# استدعاء user manager من get_user_db (تجنب التكرار)
|
| 135 |
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
|
| 136 |
yield UserManager(user_db)
|
| 137 |
|
|
|
|
| 143 |
associate_by_email=True,
|
| 144 |
redirect_url=GOOGLE_REDIRECT_URL,
|
| 145 |
)
|
|
|
|
| 146 |
|
| 147 |
github_oauth_router = get_oauth_router(
|
| 148 |
github_oauth_client,
|
|
|
|
| 152 |
associate_by_email=True,
|
| 153 |
redirect_url=GITHUB_REDIRECT_URL,
|
| 154 |
)
|
|
|
|
| 155 |
|
| 156 |
fastapi_users = FastAPIUsers[User, int](
|
| 157 |
get_user_db,
|
|
|
|
| 168 |
app.include_router(fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"])
|
| 169 |
app.include_router(fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"])
|
| 170 |
app.include_router(fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", tags=["users"])
|
| 171 |
+
|