import requests
import json
from urllib.parse import quote

from commonAuth import *

# This script retrieves SAML user information. It is an alternative to SAML attribute
# query requests (AQR), which Azure AD does not support.
#
# Provide the Azure application credentials (tenant ID, client ID and client secret)
# in the authentication.conf file or through the Splunk Web UI
# (Settings > Authentication Methods > SAML Configuration > Authentication Extensions).
# The script uses them to obtain a Microsoft Graph token and query user and group information.
#
# In authentication.conf, configure the 'scriptSecureArguments' setting with the
# credentials this script reads. For example:
#
# scriptSecureArguments = tenantId:<your Azure tenant ID>,clientId:<your application client ID>,clientSecret:<your client secret>
#
# The script also honors an optional groupType argument: the value 'transitive' queries
# transitive (nested) group membership instead of direct membership.
#
# After you restart the Splunk platform, the platform encrypts your Azure credentials.
# For more information about Splunk platform configuration files, search the
# Splunk documentation for "about configuration files".
#
# In the Splunk Web UI under Authentication Extensions > Script Secure Arguments, add one
# key/value pair per credential, for example: key = clientId, value = <your application client ID>

# Microsoft Graph base URL shared by the user endpoint and the token scope.
_GRAPH_HOST = 'https://graph.microsoft.com'

USER_ENDPOINT = _GRAPH_HOST + '/v1.0/users/'           # user lookups; the encoded username is appended
LOGIN_ENDPOINT = 'https://login.microsoftonline.com/'  # tenant ID and token path are appended in getAuthToken
GRAPH_SCOPE = _GRAPH_HOST + '/.default'                # scope requested for the Graph access token

# OAuth2 grant type sent with the token request.
CLIENT_CREDENTIALS = 'client_credentials'
# Optional input key; the value 'transitive' selects the transitiveMemberOf endpoint.
GROUP_TYPE = 'groupType'
# Timeout, in seconds, passed to every HTTP request made by this script.
request_timeout = 10

def getAuthToken(tenantId, clientId, clientSecret, logger):
    """Request a Microsoft Graph access token using the OAuth2 client-credentials grant.

    Args:
        tenantId: Azure AD tenant ID; used to build the token endpoint URL.
        clientId: Application (client) ID sent in the token request.
        clientSecret: Client secret sent in the token request.
        logger: Logger used for progress and failure messages.

    Returns:
        The ``access_token`` string from the token response, or ``FAILED``
        when the request raises, returns a non-200 status, or the body is not
        JSON / has no ``access_token`` field.
    """
    tokenEndpoint = LOGIN_ENDPOINT + tenantId + "/oauth2/v2.0/token"  # To Generate OAuth2 Token

    body = {
        'grant_type': CLIENT_CREDENTIALS,
        'scope': GRAPH_SCOPE,
        'client_id': clientId,
        'client_secret': clientSecret,
    }

    logger.info("Requesting Authentication Token for client={}".format(clientId))

    try:
        auth_response = requests.post(tokenEndpoint, data=body, timeout=request_timeout)
    except requests.exceptions.RequestException as e:
        # Timeouts / connection errors must yield FAILED rather than a traceback.
        logger.error("Failed to retrieve authorization token for client={} with error={}".format(clientId, str(e)))
        return FAILED

    if auth_response.status_code != 200:
        logger.error("Failed to retrieve authorization token for client={}".format(clientId))
        return FAILED

    try:
        # A 200 response without 'access_token' is treated the same as bad JSON.
        return json.loads(auth_response.text)['access_token']
    except (ValueError, KeyError, TypeError) as e:
        logger.error("Failed to retrieve authorization token for client={} with error={}".format(clientId, str(e)))
        return FAILED

def getUserInfo(args, logger):
    """Build the SAML user-info response for a user from Microsoft Graph.

    Azure AD does not support SAML attribute query requests, so the user is
    looked up directly in Microsoft Graph and their group memberships are listed.

    Args:
        args: Script input dictionary. Reads 'tenantId', 'clientId',
            'clientSecret' and 'username' (assumed to be the user's Azure
            UPN / email address), plus optional GROUP_TYPE ('groupType');
            the value 'transitive' selects transitive group membership.
        logger: Logger used for progress and failure messages.

    Returns:
        '<SUCCESS> --userInfo=<username>;<real name>;<group ids> --encodedOutput=true'
        with every field URL-safe base64 encoded and group IDs joined by ':'
        (the real-name field is always empty), or FAILED when the token,
        the user lookup, or any page of group results cannot be retrieved or parsed.
    """
    apiKey = getAuthToken(args['tenantId'], args['clientId'], args['clientSecret'], logger)
    if apiKey == FAILED:
        # Without a token every Graph call fails; don't send "Bearer <FAILED>".
        return FAILED

    # Assuming the username passed in is in the form of an email address corresponding
    # to the Azure user.
    username = args['username']

    AZURE_HEADERS = {'Host': 'graph.microsoft.com', 'Authorization': 'Bearer ' + apiKey}
    # safe='' also escapes '/', so the username cannot add path segments to the Graph URL.
    encoded_username = quote(username, safe='')
    # NOTE(review): no name attribute is read from the user profile, so this field is
    # always empty in the output — confirm whether e.g. displayName was intended.
    realNameStr = ''

    usernameUrl = USER_ENDPOINT + encoded_username
    try:
        usernameResponse = requests.request('GET', usernameUrl, headers=AZURE_HEADERS, timeout=request_timeout)
    except requests.exceptions.RequestException as e:
        logger.warning("Failed to get user info for username={} with error={}".format(username, str(e)))
        return FAILED

    if usernameResponse.status_code != 200:
        logger.warning("Failed to get user info for username={} with status={} and response={}".format(username, usernameResponse.status_code, usernameResponse.text))
        return FAILED

    # The profile body is not used further; parsing only confirms Graph returned JSON.
    try:
        json.loads(usernameResponse.text)
    except ValueError as e:
        logger.warning("Failed to parse user info for username={} with error={}".format(username, str(e)))
        return FAILED

    # Group-membership endpoint for the same user (addressed by username).
    groupsUrl = USER_ENDPOINT + encoded_username
    if GROUP_TYPE in args and args[GROUP_TYPE] == 'transitive':
        logger.info("Using transitive groups endpoint to query groups for username={}".format(username))
        groupsUrl += '/transitiveMemberOf'
    else:
        logger.info("Using direct groups endpoint to query groups for username={}".format(username))
        groupsUrl += '/memberOf'
    groupsUrl += '?$top=999'

    groupIds = []
    while groupsUrl:
        try:
            groupsResponse = requests.request('GET', groupsUrl, headers=AZURE_HEADERS, timeout=request_timeout)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to get user group membership for username={} with error={}".format(username, str(e)))
            return FAILED

        if groupsResponse.status_code != 200:
            logger.error("Failed to get user group membership for username={} with status={} and response={}".format(username, groupsResponse.status_code, groupsResponse.text))
            return FAILED

        try:
            groupsPage = json.loads(groupsResponse.text)
        except ValueError as e:
            logger.error("Failed to parse user groups response for username={} with error={}".format(username, str(e)))
            return FAILED

        pageEntries = groupsPage.get('value') or []
        groupIds.extend(urlsafe_b64encode_to_str(group['id']) for group in pageEntries)
        # Advance even when this page is empty; otherwise the loop would
        # re-request the same URL forever.
        groupsUrl = groupsPage.get('@odata.nextLink')

    # Returning the id associated with each group the user is a part of. SAML has to be set up
    # to use the group id from Azure AD as the SAML group name. Ref: customer case &
    # https://www.splunk.com/en_us/blog/cloud/configuring-microsoft-s-azure-security-assertion-markup-language
    # -saml-single-sign-on-sso-with-splunk-cloud-azure-portal.htm
    # Joining the full list (rather than appending per page) keeps a ':' between
    # IDs that come from different pages.
    rolesString = ':'.join(groupIds)

    base64UrlEncodedUsername = urlsafe_b64encode_to_str(username)
    base64UrlEncodedRealName = urlsafe_b64encode_to_str(realNameStr)

    return '{} --userInfo={};{};{} --encodedOutput=true'.format(SUCCESS, base64UrlEncodedUsername, base64UrlEncodedRealName, rolesString)

def login(args, logger):
    """Build the login response listing a user's Azure AD group IDs from Microsoft Graph.

    Args:
        args: Script input dictionary. Reads 'userInfo' (';'-separated; the
            first field is taken as the username, assumed to be the user's
            Azure UPN / email address), 'tenantId', 'clientId', 'clientSecret',
            and optional GROUP_TYPE ('groupType'); the value 'transitive'
            selects transitive group membership.
        logger: Logger used for progress and failure messages.

    Returns:
        '<SUCCESS> --groups=<id> --groups=<id> ... --encodedOutput=true' with
        each group ID URL-safe base64 encoded, or FAILED when the token, the
        user lookup, or any page of group results cannot be retrieved or parsed.
    """
    # Assuming the username passed in is in the form of an email address corresponding
    # to the Azure user.
    username = args['userInfo'].split(';')[0]

    apiKey = getAuthToken(args['tenantId'], args['clientId'], args['clientSecret'], logger)
    if apiKey == FAILED:
        # Without a token every Graph call fails; don't send "Bearer <FAILED>".
        return FAILED

    AZURE_HEADERS = {'Host': 'graph.microsoft.com', 'Authorization': 'Bearer ' + apiKey}
    # safe='' also escapes '/', so the username cannot add path segments to the Graph URL.
    encoded_username = quote(username, safe='')

    # The profile lookup only verifies that the user exists and Graph returns JSON.
    usernameUrl = USER_ENDPOINT + encoded_username
    try:
        usernameResponse = requests.request('GET', usernameUrl, headers=AZURE_HEADERS, timeout=request_timeout)
    except requests.exceptions.RequestException as e:
        logger.error("Failed to get user info for username={} with error={}".format(username, str(e)))
        return FAILED

    if usernameResponse.status_code != 200:
        logger.error("Failed to get user info for username={} with status={} and response={}".format(username, usernameResponse.status_code, usernameResponse.text))
        return FAILED

    try:
        json.loads(usernameResponse.text)
    except ValueError as e:
        logger.error("Failed to parse user info for username={} with error={}".format(username, str(e)))
        return FAILED

    # Group-membership endpoint for the same user (addressed by username).
    groupsUrl = USER_ENDPOINT + encoded_username
    if GROUP_TYPE in args and args[GROUP_TYPE] == 'transitive':
        logger.info("Using transitive groups endpoint to query groups for username={}".format(username))
        groupsUrl += '/transitiveMemberOf'
    else:
        logger.info("Using direct groups endpoint to query groups for username={}".format(username))
        groupsUrl += '/memberOf'
    groupsUrl += '?$top=999'

    # Accumulated across ALL pages; it must not be reset inside the loop, or
    # only the last page of groups would be returned.
    groupIds = []
    while groupsUrl:
        try:
            groupsResponse = requests.request('GET', groupsUrl, headers=AZURE_HEADERS, timeout=request_timeout)
        except requests.exceptions.RequestException as e:
            logger.error("Failed to get user group membership info for username={} with error={}".format(username, str(e)))
            return FAILED

        if groupsResponse.status_code != 200:
            logger.error("Failed to get user group membership info for username={} with status={} and response={}".format(username, groupsResponse.status_code, groupsResponse.text))
            return FAILED

        try:
            groupsPage = json.loads(groupsResponse.text)
        except ValueError as e:
            logger.error("Failed to parse user groups response for username={} with error={}".format(username, str(e)))
            return FAILED

        pageEntries = groupsPage.get('value') or []
        groupIds.extend(urlsafe_b64encode_to_str(group['id']) for group in pageEntries)
        # Advance even when this page is empty; otherwise the loop would
        # re-request the same URL forever.
        groupsUrl = groupsPage.get('@odata.nextLink')

    # Returning the id associated with each group the user is a part of. SAML has to be set up
    # to use the group id from Azure AD as the SAML group name. Ref: customer case &
    # https://www.splunk.com/en_us/blog/cloud/configuring-microsoft-s-azure-security-assertion-markup-language
    # -saml-single-sign-on-sso-with-splunk-cloud-azure-portal.htm
    groupArgs = ['--groups={}'.format(groupId) for groupId in groupIds]
    return '{} {} --encodedOutput=true'.format(SUCCESS, ' '.join(groupArgs))

if __name__ == "__main__":
    # The operation to perform is the first command-line argument.
    callName = sys.argv[1]
    dictIn = readInputs()
    logger = getLogger("{}/splunk_scripted_authentication_azure.log".format(logPath), "azure")

    # Dispatch to the matching handler; an unrecognized name prints nothing.
    handlers = {"getUserInfo": getUserInfo, "login": login}
    handler = handlers.get(callName)
    if handler is not None:
        print(handler(dictIn, logger))

