from argparse import _SubParsersAction
from functools import partial
from typing import List, Optional
from lxml.objectify import BoolElement, IntElement, StringElement
from q2_sdk.core.dynamic_imports import (
api_ExecuteStoredProcedure as ExecuteStoredProcedure,
)
from q2_sdk.core.cli import textui
from q2_sdk.core.cli.cli_tools import SelectMenu, MenuOption
from q2_sdk.tools.decorators import dev_only
from .db_object import DbObject
from .group import Group
from .wedge_address import WedgeAddress
from .representation_row_base import RepresentationRowBase
class MfaProviderRow(RepresentationRowBase):
    """Row type for MFA provider results.

    Used as ``representation_class_override`` for the ``sdk_GetMFAProviders``
    and ``sdk_AddMFAProvider`` stored-procedure calls in ``Mfa``.
    """

    # Each value is presumably the element/column name RepresentationRowBase
    # reads for the attribute -- confirm against RepresentationRowBase.
    MFAProviderID: IntElement = "MFAProviderID"
    ShortName: StringElement = "ShortName"
    RegistrationRequired: BoolElement = "RegistrationRequired"
    WedgeAddressID: IntElement = "WedgeAddressID"
    TokenLifetimeInMinutes: IntElement = "TokenLifetimeInMinutes"
class MfaRegistrationRow(RepresentationRowBase):
    """Row type for MFA registration results.

    Used as ``representation_class_override`` for ``sdk_GetMFARegistrations``
    and ``sdk_GetMFARegistrationsByProviderName`` in ``Mfa``.
    """

    # NOTE(review): the CLI column list for registrations asks for
    # "ProviderName", which is not declared here (only MFAProviderID is);
    # presumably the sprocs return it as an extra column -- confirm.
    MFARegistrationID: IntElement = "MFARegistrationID"
    MFAProviderID: IntElement = "MFAProviderID"
    UserID: IntElement = "UserID"
    RegistrationValue: StringElement = "RegistrationValue"
    # Dates are typed as strings here; their format is not visible in this file.
    CreateDate: StringElement = "CreateDate"
    UpdatedDate: StringElement = "UpdatedDate"
    DeletedDate: StringElement = "DeletedDate"
class MfaGroupProfileRow(RepresentationRowBase):
    """Row type for MFA group profile results.

    Used as ``representation_class_override`` for ``sdk_GetMFAGroupProfiles``
    and ``sdk_GetMFAGroupProfilesByProviderName`` in ``Mfa``. A profile
    carries three provider slots (auth, transaction auth, EDV patrol), each
    exposed as an ID plus the provider's ShortName.
    """

    MFAGroupProfileID: IntElement = "MFAGroupProfileID"
    Description: StringElement = "Description"
    AuthProviderID: IntElement = "AuthProviderID"
    AuthProviderShortName: StringElement = "AuthProviderShortName"
    TranAuthProviderID: IntElement = "TranAuthProviderID"
    TranAuthProviderShortName: StringElement = "TranAuthProviderShortName"
    EDVPatrolProviderID: IntElement = "EDVPatrolProviderID"
    EDVPatrolProviderShortName: StringElement = "EDVPatrolProviderShortName"
class GroupToMfaGroupProfileRow(RepresentationRowBase):
    """Row type linking a group to an MFA group profile.

    Used as ``representation_class_override`` for
    ``sdk_GetGroupToMFAGroupProfile`` and
    ``sdk_GetGroupToMFAGroupProfileForProfile`` in ``Mfa``.
    """

    # NOTE(review): unlike every other row class in this module, the attribute
    # name (MFAGroupProfileID) differs from its mapped value
    # ("GroupToMfaGroupProfileID"). Confirm which column the sproc actually
    # returns and whether this mapping is intentional.
    MFAGroupProfileID: IntElement = "GroupToMfaGroupProfileID"
    ProfileDescription: StringElement = "ProfileDescription"
    GroupID: IntElement = "GroupID"
class Mfa(DbObject):
    """Database access for MFA providers, MFA group profiles and group links.

    Every query is a stored-procedure call made through ``self.call_hq``.
    The ``get_*`` methods return the rows produced by ``call_hq`` (typed via
    ``representation_class_override``); when ``serialize_for_cli`` is true
    they instead return ``self.serialize_for_cli(rows, columns)`` for a fixed
    column list. :meth:`add_arguments` exposes these methods as CLI
    subcommands.
    """

    NAME = "Mfa"

    # Column lists handed to ``serialize_for_cli``. Each is shared by the
    # unfiltered and filtered variant of the same query so they cannot drift.
    _PROVIDER_COLUMNS = (
        "MFAProviderID",
        "ShortName",
        "RegistrationRequired",
        "WedgeAddressID",
        "TokenLifetimeInMinutes",
    )
    _GROUP_PROFILE_COLUMNS = (
        "MFAGroupProfileID",
        "Description",
        "AuthProviderID",
        "AuthProviderShortName",
        "TranAuthProviderID",
        "TranAuthProviderShortName",
        "EDVPatrolProviderID",
        "EDVPatrolProviderShortName",
    )
    # NOTE(review): "ProviderName" is not declared on MfaRegistrationRow
    # (which has MFAProviderID); presumably the registration sprocs return it
    # as an extra column -- confirm.
    _REGISTRATION_COLUMNS = (
        "MFARegistrationID",
        "ProviderName",
        "UserID",
        "RegistrationValue",
        "CreateDate",
        "UpdatedDate",
        "DeletedDate",
    )
    _GROUP_TO_PROFILE_COLUMNS = ("MFAGroupProfileID", "ProfileDescription", "GroupID")

    def add_arguments(self, parser: _SubParsersAction):
        """Register the MFA CLI subcommands on ``parser``.

        Each subcommand stores its own name under ``parser`` and the method
        to run under ``func``. Positional/optional arguments are named after
        the parameters of the method they feed (presumably the CLI runner
        forwards parsed arguments as keyword arguments -- confirm).

        :param parser: The sub-parsers action to attach subcommands to.
        """
        subparser = parser.add_parser("get_mfa_providers")
        subparser.set_defaults(parser="get_mfa_providers")
        subparser.set_defaults(func=partial(self.get_providers, serialize_for_cli=True))

        subparser = parser.add_parser("get_mfa_group_profiles")
        subparser.set_defaults(parser="get_mfa_group_profiles")
        subparser.set_defaults(
            func=partial(self.get_mfa_group_profiles, serialize_for_cli=True)
        )

        subparser = parser.add_parser("get_mfa_group_profiles_for_provider")
        subparser.set_defaults(parser="get_mfa_group_profiles_for_provider")
        subparser.set_defaults(
            func=partial(
                self.get_mfa_group_profiles_for_provider, serialize_for_cli=True
            )
        )
        subparser.add_argument("provider_name", help="Q2_MFAProvider.ShortName")

        subparser = parser.add_parser("get_group_to_mfa_group_profile")
        subparser.set_defaults(parser="get_group_to_mfa_group_profile")
        subparser.set_defaults(
            func=partial(self.get_group_to_mfa_group_profile, serialize_for_cli=True)
        )

        subparser = parser.add_parser("get_group_to_mfa_group_profile_for_profile")
        subparser.set_defaults(parser="get_group_to_mfa_group_profile_for_profile")
        subparser.set_defaults(
            func=partial(
                self.get_group_to_mfa_group_profile_for_profile, serialize_for_cli=True
            )
        )
        # The target method requires ``profile_id``; without this argument the
        # subcommand had no way to supply it.
        subparser.add_argument(
            "profile_id", type=int, help="Q2_MFAGroupProfile.MFAGroupProfileID"
        )

        subparser = parser.add_parser("add_mfa_provider")
        subparser.set_defaults(parser="add_mfa_provider")
        subparser.set_defaults(func=partial(self.add_provider))
        subparser.add_argument("short_name", help="Q2_MFAProvider.ShortName")
        subparser.add_argument(
            "wedge_address_short_name", help="Q2_WedgeAddress.ShortName"
        )
        subparser.add_argument(
            "-t",
            "--token-lifetime-in-minutes",
            default=5,
            # Parsed as int so user input matches the Int SqlParam it feeds.
            type=int,
            help="Q2_MFAProvider.TokenLifetimeInMinutes (default 5)",
        )
        subparser.add_argument(
            "-r",
            "--registration-required",
            action="store_true",
            default=False,
            help="Q2_MFAProvider.RegistrationRequired (default False)",
        )

        subparser = parser.add_parser("add_mfa_group_profile")
        subparser.set_defaults(parser="add_mfa_group_profile")
        subparser.set_defaults(func=partial(self._prompt_for_add_group_profile))

        subparser = parser.add_parser("remove_mfa_provider")
        subparser.set_defaults(parser="remove_mfa_provider")
        subparser.set_defaults(func=partial(self.remove_provider))
        subparser.add_argument("short_name", help="Q2_MFAProvider.ShortName")

        subparser = parser.add_parser("remove_mfa_group_profile")
        subparser.set_defaults(parser="remove_mfa_group_profile")
        subparser.set_defaults(func=partial(self.remove_group_profile))
        subparser.add_argument(
            "mfa_group_profile_id",
            # Parsed as int so user input matches the Int SqlParam it feeds.
            type=int,
            help="Q2_MFAGroupProfile.MFAGroupProfileID",
        )

        subparser = parser.add_parser("update_group_to_mfa_group_profile")
        subparser.set_defaults(parser="update_group_to_mfa_group_profile")
        subparser.set_defaults(func=partial(self.update_group_to_mfa_group_profile))

    async def get_group_to_mfa_group_profile(
        self, serialize_for_cli=False
    ) -> List[GroupToMfaGroupProfileRow]:
        """Return every group-to-MFA-group-profile link.

        Calls ``sdk_GetGroupToMFAGroupProfile``.

        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        """
        response = await self.call_hq(
            "sdk_GetGroupToMFAGroupProfile",
            representation_class_override=GroupToMfaGroupProfileRow,
        )
        if serialize_for_cli:
            response = self.serialize_for_cli(
                response, list(self._GROUP_TO_PROFILE_COLUMNS)
            )
        return response

    async def get_group_to_mfa_group_profile_for_profile(
        self, profile_id: int, serialize_for_cli=False
    ) -> List[GroupToMfaGroupProfileRow]:
        """Return the group links for one MFA group profile.

        Calls ``sdk_GetGroupToMFAGroupProfileForProfile`` with ``profile_id``.

        :param profile_id: The MFA group profile to look up.
        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        """
        response = await self.call_hq(
            "sdk_GetGroupToMFAGroupProfileForProfile",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "profile_id",
                    profile_id,
                ),
            ]),
            representation_class_override=GroupToMfaGroupProfileRow,
        )
        if serialize_for_cli:
            response = self.serialize_for_cli(
                response, list(self._GROUP_TO_PROFILE_COLUMNS)
            )
        return response

    async def get_providers(
        self,
        serialize_for_cli=False,
        short_name: Optional[str] = None,
    ) -> List[MfaProviderRow]:
        """Return MFA providers, optionally filtered by ShortName.

        Calls ``sdk_GetMFAProviders`` and, when ``short_name`` is truthy, keeps
        only rows whose ``ShortName`` equals it (filtering happens here, not
        in SQL).

        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        :param short_name: Optional exact ShortName to filter on.
        """
        response = await self.call_hq(
            "sdk_GetMFAProviders", representation_class_override=MfaProviderRow
        )
        if short_name:
            response = [x for x in response if x.ShortName == short_name]
        if serialize_for_cli:
            response = self.serialize_for_cli(response, list(self._PROVIDER_COLUMNS))
        return response

    async def get_mfa_group_profiles(
        self, serialize_for_cli=False
    ) -> List[MfaGroupProfileRow]:
        """Return every MFA group profile.

        Calls ``sdk_GetMFAGroupProfiles``.

        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        """
        response = await self.call_hq(
            "sdk_GetMFAGroupProfiles", representation_class_override=MfaGroupProfileRow
        )
        if serialize_for_cli:
            response = self.serialize_for_cli(
                response, list(self._GROUP_PROFILE_COLUMNS)
            )
        return response

    async def get_mfa_group_profiles_for_provider(
        self, provider_name: str, serialize_for_cli=False
    ) -> List[MfaGroupProfileRow]:
        """Return the MFA group profiles associated with a provider.

        Calls ``sdk_GetMFAGroupProfilesByProviderName`` with ``provider_name``.

        :param provider_name: The provider's ShortName.
        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        """
        response = await self.call_hq(
            "sdk_GetMFAGroupProfilesByProviderName",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.VarChar,
                    "provider_name",
                    provider_name,
                ),
            ]),
            representation_class_override=MfaGroupProfileRow,
        )
        if serialize_for_cli:
            response = self.serialize_for_cli(
                response, list(self._GROUP_PROFILE_COLUMNS)
            )
        return response

    async def get_registrations(
        self, serialize_for_cli=False
    ) -> List[MfaRegistrationRow]:
        """Return every MFA registration.

        Calls ``sdk_GetMFARegistrations``.

        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        """
        response = await self.call_hq(
            "sdk_GetMFARegistrations", representation_class_override=MfaRegistrationRow
        )
        if serialize_for_cli:
            response = self.serialize_for_cli(
                response, list(self._REGISTRATION_COLUMNS)
            )
        return response

    async def get_registrations_for_provider(
        self, provider_name: str, serialize_for_cli=False
    ) -> List[MfaRegistrationRow]:
        """Return the MFA registrations for one provider.

        Calls ``sdk_GetMFARegistrationsByProviderName`` with ``provider_name``.

        :param provider_name: The provider's ShortName.
        :param serialize_for_cli: If true, return the rows passed through
            ``self.serialize_for_cli`` instead of the raw rows.
        """
        response = await self.call_hq(
            "sdk_GetMFARegistrationsByProviderName",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.VarChar,
                    "provider_name",
                    provider_name,
                ),
            ]),
            representation_class_override=MfaRegistrationRow,
        )
        if serialize_for_cli:
            response = self.serialize_for_cli(
                response, list(self._REGISTRATION_COLUMNS)
            )
        return response

    async def add_provider(
        self,
        short_name: str,
        wedge_address_short_name: str,
        registration_required: bool,
        token_lifetime_in_minutes=5,
    ) -> bool:
        """Create an MFA provider bound to an existing wedge address.

        Looks up the wedge address by ShortName, then calls
        ``sdk_AddMFAProvider`` with its ``WedgeAddressID``.

        :param short_name: ShortName for the new provider.
        :param wedge_address_short_name: ShortName of the wedge address to use.
        :param registration_required: Value for RegistrationRequired.
        :param token_lifetime_in_minutes: Value for TokenLifetimeInMinutes.
        :returns: Whatever ``call_hq`` returns for the insert.
        """
        wa_obj = WedgeAddress(
            self.logger, hq_credentials=self.hq_credentials, ret_table_obj=True
        )
        # NOTE(review): if get_by_name returns None for an unknown name, the
        # attribute access below fails with AttributeError -- confirm its
        # not-found behavior.
        wedge_address_row = await wa_obj.get_by_name(wedge_address_short_name)
        response = await self.call_hq(
            "sdk_AddMFAProvider",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.VarChar,
                    "short_name",
                    short_name,
                ),
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Bit,
                    "registration_required",
                    registration_required,
                ),
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "wedge_address_id",
                    wedge_address_row.WedgeAddressID,
                ),
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "token_lifetime_in_minutes",
                    token_lifetime_in_minutes,
                ),
            ]),
            representation_class_override=MfaProviderRow,
        )
        return response

    async def _prompt_for_add_group_profile(self):
        """Interactively create an MFA group profile.

        Asks for a description and a single provider, and uses that provider
        for all three slots (auth, transaction auth, EDV patrol).
        """
        description = textui.query("Please provide a description for this MFA group")
        providers = await self.get_providers()
        selected_provider: int = SelectMenu(
            [MenuOption(x.ShortName, x.MFAProviderID.pyval) for x in providers],
            "Select MFA Provider for this MFA Group Profile",
        ).prompt()
        await self.add_group_profile(
            description,
            selected_auth_provider=selected_provider,
            selected_tran_auth_provider=selected_provider,
            selected_edv_patrol_provider=selected_provider,
        )

    async def add_group_profile(
        self,
        description: str,
        selected_auth_provider: Optional[int],
        selected_tran_auth_provider: Optional[int],
        selected_edv_patrol_provider: Optional[int],
    ):
        """Create an MFA group profile via ``sdk_AddMFAGroupProfile``.

        :param description: Profile description.
        :param selected_auth_provider: MFAProviderID for the auth slot.
        :param selected_tran_auth_provider: MFAProviderID for the transaction
            auth slot.
        :param selected_edv_patrol_provider: MFAProviderID for the EDV patrol
            slot.
        """
        await self.call_hq(
            "sdk_AddMFAGroupProfile",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.VarChar,
                    "description",
                    description,
                ),
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "auth_provider_id",
                    selected_auth_provider,
                ),
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "tran_auth_provider_id",
                    selected_tran_auth_provider,
                ),
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "edv_patrol_provider_id",
                    selected_edv_patrol_provider,
                ),
            ]),
        )

    async def update_group_to_mfa_group_profile(self):
        """Interactively re-assign which groups belong to an MFA group profile.

        Prompts for a profile, then lists every group with the profile's
        current members pre-toggled. Newly selected groups are linked via
        :meth:`add_group_to_mfa_group_profile`. Previously linked groups that
        are no longer selected are unlinked via
        :meth:`remove_group_to_mfa_group_profile`.
        """
        profiles = await self.get_mfa_group_profiles()
        selected_mfa_group_profile: int = SelectMenu(
            [MenuOption(x.Description, x.MFAGroupProfileID.pyval) for x in profiles],
            "Select MFA Group Profile",
        ).prompt()
        group_obj = Group(
            self.logger, hq_credentials=self.hq_credentials, ret_table_obj=True
        )
        groups = await group_obj.get()
        current_groups = await self.get_group_to_mfa_group_profile_for_profile(
            selected_mfa_group_profile
        )
        # Normalise to plain ints so membership tests compare like with like:
        # the menu values below are ``.pyval`` ints, while the rows are lxml
        # elements.
        current_group_ids = [int(row["GroupID"]) for row in current_groups]
        selected_groups = SelectMenu(
            [
                MenuOption(
                    x.GroupDesc,
                    x.GroupID.pyval,
                    toggle=x.GroupID.pyval in current_group_ids,
                )
                for x in groups
            ],
            "Select Groups for MFA Group Profile",
        ).prompt()
        await self.add_group_to_mfa_group_profile(
            selected_mfa_group_profile,
            [gid for gid in selected_groups if gid not in current_group_ids],
        )
        # NOTE(review): remove_group_to_mfa_group_profile is decorated with
        # @dev_only while this method is not; confirm how the removal step
        # behaves outside a dev environment.
        await self.remove_group_to_mfa_group_profile(
            selected_mfa_group_profile,
            [gid for gid in current_group_ids if gid not in selected_groups],
        )

    async def add_group_to_mfa_group_profile(
        self,
        selected_mfa_group_profile: int,
        selected_groups: list[int],
    ):
        """Link each group in ``selected_groups`` to an MFA group profile.

        Issues one ``sdk_AddGroupToMFAGroupProfile`` call per group. This is a
        no-op for an empty list.

        :param selected_mfa_group_profile: The MFA group profile ID.
        :param selected_groups: Group IDs to link.
        """
        for group_id in selected_groups:
            await self.call_hq(
                "sdk_AddGroupToMFAGroupProfile",
                sql_parameters=ExecuteStoredProcedure.SqlParameters([
                    ExecuteStoredProcedure.SqlParam(
                        ExecuteStoredProcedure.DataType.Int,
                        "mfa_group_profile_id",
                        selected_mfa_group_profile,
                    ),
                    ExecuteStoredProcedure.SqlParam(
                        ExecuteStoredProcedure.DataType.Int,
                        "group_id",
                        group_id,
                    ),
                ]),
            )

    @dev_only
    async def remove_group_to_mfa_group_profile(
        self, mfa_group_profile_id: int, removed_groups: list[int]
    ):
        """Unlink each group in ``removed_groups`` from an MFA group profile.

        Issues one ``sdk_RemoveGroupToMFAGroupProfile`` call per group. IDs
        are coerced with ``int()``, so int-convertible values such as lxml
        IntElements are accepted.

        :param mfa_group_profile_id: The MFA group profile ID.
        :param removed_groups: Group IDs to unlink.
        """
        for group_id in removed_groups:
            await self.call_hq(
                "sdk_RemoveGroupToMFAGroupProfile",
                sql_parameters=ExecuteStoredProcedure.SqlParameters([
                    ExecuteStoredProcedure.SqlParam(
                        ExecuteStoredProcedure.DataType.Int,
                        "group_id",
                        int(group_id),
                    ),
                    ExecuteStoredProcedure.SqlParam(
                        ExecuteStoredProcedure.DataType.Int,
                        "mfa_group_profile_id",
                        int(mfa_group_profile_id),
                    ),
                ]),
            )

    @dev_only
    async def remove_provider(self, short_name: str):
        """Delete an MFA provider by ShortName via ``sdk_RemoveMFAProvider``.

        :param short_name: The provider's ShortName.
        :returns: Whatever ``call_hq`` returns for the delete.
        """
        response = await self.call_hq(
            "sdk_RemoveMFAProvider",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.VarChar,
                    "short_name",
                    short_name,
                ),
            ]),
        )
        return response

    @dev_only
    async def remove_group_profile(self, mfa_group_profile_id: int):
        """Delete an MFA group profile via ``sdk_RemoveMFAGroupProfile``.

        :param mfa_group_profile_id: The MFA group profile ID.
        :returns: Whatever ``call_hq`` returns for the delete.
        """
        response = await self.call_hq(
            "sdk_RemoveMFAGroupProfile",
            sql_parameters=ExecuteStoredProcedure.SqlParameters([
                ExecuteStoredProcedure.SqlParam(
                    ExecuteStoredProcedure.DataType.Int,
                    "mfa_group_profile_id",
                    mfa_group_profile_id,
                ),
            ]),
        )
        return response