#!/opt/local/bin/python3.12

# oauth2nmail
#
# Copyright (c) 2021-2024 Kristofer Berggren
# All rights reserved.
#
# nmail is distributed under the MIT license, see LICENSE for details.

from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import urlparse
import http.client
import json
import os
import stat
import sys
import threading
import time
import urllib.parse
import webbrowser

# Tool version, printed by show_version()
version = "1.03"
# Request path of the OAuth2 redirect received by the local web server in
# generate(); stays empty until the browser has been redirected back
pathServed = ""
# Per-provider OAuth2 settings, keyed by the OAUTH2_TYPE environment value:
#   auth_scope  scope requested in the browser authorization URL
#   mail_scope  scope sent when exchanging the auth code for mail tokens
#   user_scope  scope for the separate user info token (requested for outlook only)
#   auth_url    authorization endpoint opened in the web browser
#   conv_url    token endpoint used to convert the auth code into tokens
#   refr_url    token endpoint used to refresh the access token
#   info_url    endpoint queried for the user's email address and name
providers = {
    "gmail-oauth2": {
        "auth_scope": "openid profile email https://mail.google.com/",
        "mail_scope": "openid profile email https://mail.google.com/",
        "user_scope": "openid profile email https://mail.google.com/",
        "auth_url": "https://accounts.google.com/o/oauth2/auth",
        "conv_url": "https://accounts.google.com/o/oauth2/token",
        "refr_url": "https://accounts.google.com/o/oauth2/token",
        "info_url": "https://openidconnect.googleapis.com/v1/userinfo",
    },
    "outlook-oauth2": {
        "auth_scope": "offline_access https://outlook.office.com/IMAP.AccessAsUser.All https://outlook.office.com/SMTP.Send https://graph.microsoft.com/User.Read",
        "mail_scope": "offline_access https://outlook.office.com/IMAP.AccessAsUser.All https://outlook.office.com/SMTP.Send",
        "user_scope": "offline_access https://graph.microsoft.com/User.Read",
        "auth_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
        "conv_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
        "refr_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
        "info_url": "https://graph.microsoft.com/v1.0/me",
    },
}


def show_help():
    """Print the usage text (options, return values, examples) to stdout."""
    help_lines = (
        "oauth2nmail is a utility tool for handling OAuth2 authentication and",
        "token renewals. Primarily written for use by nmail email client.",
        "Parameters must be passed using environment variables, see examples below.",
        "",
        "Usage: oauth2nmail --generate",
        "   or: oauth2nmail --refresh",
        "   or: oauth2nmail --help",
        "   or: oauth2nmail --version",
        "",
        "Options:",
        "   -g, --generate    perform authentication and generate refresh/access tokens",
        "   -r, --refresh     refresh access token using refresh token",
        "   -h, --help        display this help and exit",
        "   -v, --version     output version information and exit",
        "",
        "Return values:",
        "   0                 success",
        "   1                 syntax / usage error",
        "   2                 authentication timeout",
        "   3                 user permission not granted",
        "   4                 http request network error",
        "   5                 http request returned non-200",
        "   6                 access token not available",
        "   7                 refresh token not available",
        "   130               keyboard interrupt (ctrl-c)",
        "",
        "Examples:",
        '   OAUTH2_TYPE="gmail-oauth2" OAUTH2_CLIENT_ID="9" OAUTH2_CLIENT_SECRET="j" \\',
        '   OAUTH2_TOKEN_STORE="${HOME}/.config/nmail/oauth2.tokens" oauth2nmail -g',
        "",
        '   OAUTH2_TYPE="gmail-oauth2" OAUTH2_CLIENT_ID="9" OAUTH2_CLIENT_SECRET="j" \\',
        '   OAUTH2_TOKEN_STORE="${HOME}/.config/nmail/oauth2.tokens" oauth2nmail -r',
        "",
        "Report bugs at https://github.com/d99kris/nmail",
        "",
    )
    # one print call; the trailing newline it appends terminates the last line
    print("\n".join(help_lines))


def show_version():
    """Print the tool name with its version, followed by copyright and license notes."""
    banner = [
        "oauth2nmail v" + version,
        "",
        "Copyright (c) 2021-2024 Kristofer Berggren",
        "",
        "oauth2nmail is distributed under the MIT license.",
        "",
        "Written by Kristofer Berggren.",
    ]
    for banner_line in banner:
        print(banner_line)


# HttpResponse, get() and post() below constitute a poor man's 'requests' replacement
class HttpResponse:
    """Minimal HTTP response container holding status, headers and decoded body.

    Attributes:
        status_code: numeric HTTP status of the response.
        headers: response headers as supplied by the creator of the object.
        text: response body as text.
    """

    def __init__(self, status_code, headers, text):
        self.status_code = status_code
        self.headers = headers
        self.text = text

    def __repr__(self):
        # The error paths in this file log str(response); without this they
        # would only show the default object address instead of the status
        return "<HttpResponse [%s]>" % self.status_code


def get(url, headers=None):
    """Perform an HTTP(S) GET request and return the result as an HttpResponse.

    Parameters:
        url: absolute URL; the "https" scheme selects a TLS connection, any
            other scheme a plain HTTP connection.
        headers: optional dict of request headers.

    Returns:
        HttpResponse with the status code, the response headers and the
        decoded response body.

    Raises:
        Whatever http.client raises for invalid URLs and network failures;
        the connection is closed in that case as well.
    """
    if headers is None:
        headers = {}

    # Parse the URL and pick the connection type matching its scheme
    parsed_url = urllib.parse.urlparse(url)
    if parsed_url.scheme == 'https':
        connection = http.client.HTTPSConnection(parsed_url.netloc)
    else:
        connection = http.client.HTTPConnection(parsed_url.netloc)

    # Construct the request target; an empty path followed by a query string
    # would not be a valid target, so fall back to the root path
    path = parsed_url.path or "/"
    if parsed_url.query:
        path += '?' + parsed_url.query

    try:
        # Send the GET request and read the complete response
        connection.request("GET", path, headers=headers)
        response = connection.getresponse()
        status = response.status
        response_headers = response.getheaders()
        response_data = response.read().decode()
    finally:
        # Close the connection even if the request or the read failed
        connection.close()

    return HttpResponse(status, response_headers, response_data)


def post(url, data, headers=None):
    """Perform an HTTP(S) POST request and return the result as an HttpResponse.

    Parameters:
        url: absolute URL; the "https" scheme selects a TLS connection, any
            other scheme a plain HTTP connection.
        data: request body. A dict is serialized according to the
            Content-Type header (form-urlencoded or JSON); a str is sent
            UTF-8 encoded; bytes are sent as they are.
        headers: optional dict of request headers. Defaults to a JSON
            Content-Type when omitted.

    Returns:
        HttpResponse with the status code, the response headers and the
        decoded response body.

    Raises:
        ValueError: if the Content-Type header is neither
            "application/x-www-form-urlencoded" nor "application/json".
        Whatever http.client raises for invalid URLs and network failures;
        the connection is closed in that case as well.
    """
    if headers is None:
        headers = {"Content-Type": "application/json"}

    # Parse the URL and pick the connection type matching its scheme
    parsed_url = urllib.parse.urlparse(url)
    if parsed_url.scheme == 'https':
        connection = http.client.HTTPSConnection(parsed_url.netloc)
    else:
        connection = http.client.HTTPConnection(parsed_url.netloc)

    # Serialize the data according to the declared content type
    content_type = headers.get("Content-Type")
    if content_type == "application/x-www-form-urlencoded":
        if isinstance(data, dict):
            data = urllib.parse.urlencode(data)
    elif content_type == "application/json":
        if isinstance(data, dict):
            data = json.dumps(data)
    else:
        raise ValueError("Unsupported Content-Type")
    if isinstance(data, str):
        data = data.encode('utf-8')

    # Construct the request target, keeping any query string of the URL
    path = parsed_url.path or "/"
    if parsed_url.query:
        path += '?' + parsed_url.query

    try:
        # Send the POST request and read the complete response
        connection.request("POST", path, body=data, headers=headers)
        response = connection.getresponse()
        status = response.status
        response_headers = response.getheaders()
        response_data = response.read().decode()
    finally:
        # Close the connection even if the request or the read failed
        connection.close()

    return HttpResponse(status, response_headers, response_data)


def url_params(params):
    """Join a dict into a "key=value&key=value" query string.

    Values are percent-encoded, keeping only "~-._" (besides alphanumerics)
    unescaped; keys are inserted as they are.
    """
    return "&".join(
        "%s=%s" % (name, urllib.parse.quote(value, safe="~-._"))
        for name, value in params.items()
    )


def save_tokens(tokenStore, tokenItems):
    """Write (key, value) pairs to the file tokenStore, one "key=value" line each.

    The file is overwritten; keys and values are converted with str().
    """
    with open(tokenStore, "w") as storeFile:
        storeFile.writelines(f"{name!s}={content!s}\n" for name, content in tokenItems)


def read_tokens(tokenStore):
    """Read "key=value" lines from the file tokenStore into a dict.

    Each line is split at its first "=", so values may themselves contain
    "=". Lines without any "=" (e.g. blank lines) are skipped.

    Parameters:
        tokenStore: path of the token file.

    Returns:
        dict mapping keys to string values; empty if tokenStore is not an
        existing regular file.
    """
    tokens = {}
    if not os.path.isfile(tokenStore):
        return tokens

    with open(tokenStore, "r") as tokenFile:
        for tokenLine in tokenFile.read().splitlines():
            key, separator, value = tokenLine.partition("=")
            if not separator:
                # no key/value pair on this line; skipping it avoids failing
                # on a blank or otherwise malformed line
                continue
            tokens[key] = value
    return tokens


def _error_details(responseText):
    """Return an HTTP error body for logging: the parsed JSON as a string, else the raw text."""
    try:
        return str(json.loads(responseText))
    except ValueError:
        # error bodies are not guaranteed to be JSON
        return responseText


def generate(provider, client_id, client_secret, token_store):
    """Perform interactive OAuth2 authentication and store the resulting tokens.

    Starts a local web server on localhost:6880 for the OAuth2 redirect, opens
    the provider's authorization page in the default web browser and waits up
    to 60 seconds for the redirect. The received auth code is exchanged for
    tokens, the user's email address and name are requested, and everything
    is written to the token store.

    Parameters:
        provider: key into the module-level providers dict.
        client_id: OAuth2 client id.
        client_secret: OAuth2 client secret (only sent for "gmail-oauth2").
        token_store: path of the file the tokens are written to.

    Returns:
        0 on success, 2 on authentication timeout, 3 if the redirect carries
        no auth code, 4 on http request failure, 5 on a non-200 response,
        6 if no access token was received, 7 if no refresh token is available
        for the outlook user info token request.
    """
    global pathServed

    localHostName = "localhost"
    localHostPort = 6880
    redirectPage = "/oauth2nmail"
    redirectUri = "http://" + localHostName + ":" + str(localHostPort) + redirectPage

    # forget any redirect captured by an earlier call in this process
    pathServed = ""
    redirectReceived = threading.Event()

    # local web server handling redirect
    class LocalServer(BaseHTTPRequestHandler):
        def log_message(self, format, *args):
            # suppress the default per-request logging
            return

        def do_GET(self):
            global pathServed
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            if self.path.startswith(redirectPage):
                self.wfile.write(bytes("<html><head><title>Nmail Oauth2</title></head>", "utf-8"))
                self.wfile.write(bytes("<body><h2>Authentication successful</h2>", "utf-8"))
                self.wfile.write(bytes("<p>You may close this browser window now.</p>", "utf-8"))
                self.wfile.write(bytes("</body></html>", "utf-8"))
                pathServed = self.path
                redirectReceived.set()

    # start listening before the browser is opened, so that a quick redirect
    # cannot arrive while nothing is bound to the port yet; the thread is a
    # daemon so it does not keep the process alive
    localServer = HTTPServer((localHostName, localHostPort), LocalServer)
    thread = threading.Thread(target=localServer.serve_forever)
    thread.daemon = True
    thread.start()

    # open default web browser for user authentication
    params = {}
    params["client_id"] = client_id
    params["redirect_uri"] = redirectUri
    params["scope"] = providers[provider]["auth_scope"]
    params["response_type"] = "code"
    url = providers[provider]["auth_url"] + "?" + url_params(params)
    webbrowser.open(url, new=0, autoraise=True)

    # wait for the redirect
    redirectReceived.wait(60)
    if not pathServed:
        sys.stderr.write("authentication timeout\n")
        return 2

    # parse_qs percent-decodes the values and copes with an empty query and
    # with "=" inside values
    query = urlparse(pathServed).query
    queryDict = urllib.parse.parse_qs(query)
    if "code" not in queryDict:
        sys.stderr.write("user did not grant permission\n")
        return 3

    code = queryDict["code"][0]

    # exchange auth code for mail access and refresh token
    convParams = {}
    convParams["client_id"] = client_id
    convParams["code"] = code
    convParams["redirect_uri"] = redirectUri
    convParams["grant_type"] = "authorization_code"
    convParams["scope"] = providers[provider]["mail_scope"]
    if provider == "gmail-oauth2":
        convParams["client_secret"] = client_secret

    convUrl = providers[provider]["conv_url"]
    convHdr = {'Content-Type': 'application/x-www-form-urlencoded'}
    try:
        convResponse = post(convUrl, data=convParams, headers=convHdr)
    except Exception as e:
        sys.stderr.write("token request http post failed " + str(e) + "\n")
        return 4

    if convResponse.status_code != 200:
        sys.stderr.write("token request failed " + str(convResponse) + "\n")
        sys.stderr.write(_error_details(convResponse.text) + "\n")
        return 5

    # save tokens, then read them back in their stored (string) form
    jsonResponse = json.loads(convResponse.text)
    save_tokens(token_store, jsonResponse.items())
    tokens = read_tokens(token_store)

    # ensure access_token is present
    if "access_token" in tokens:
        access_token = tokens["access_token"]
    else:
        sys.stderr.write("access_token not available\n")
        return 6

    # use special access token for outlook user id
    if provider == "outlook-oauth2":
        if "refresh_token" not in tokens:
            sys.stderr.write("refresh_token not available\n")
            return 7

        # use the refresh token to request a user info access token
        convParams = {}
        convParams["client_id"] = client_id
        convParams["grant_type"] = "refresh_token"
        convParams["refresh_token"] = tokens["refresh_token"]
        convParams["scope"] = providers[provider]["user_scope"]
        convParams["requested_token_use"] = "on_behalf_of"
        convUrl = providers[provider]["conv_url"]
        convHdr = {'Content-Type': 'application/x-www-form-urlencoded'}
        try:
            convResponse = post(convUrl, data=convParams, headers=convHdr)
        except Exception as e:
            sys.stderr.write("user token request http post failed " + str(e) + "\n")
            return 4

        if convResponse.status_code != 200:
            sys.stderr.write("user token request failed " + str(convResponse) + "\n")
            sys.stderr.write(_error_details(convResponse.text) + "\n")
            return 5

        # ensure access_token is present; take it from this response only, so
        # that the mail access token is never used in its place
        access_token = json.loads(convResponse.text).get("access_token")
        if not access_token:
            sys.stderr.write("user access_token not available\n")
            return 6

    # request email address
    infoUrl = providers[provider]["info_url"]
    infoHeaders = {'Authorization': 'Bearer ' + access_token}
    try:
        infoResponse = get(infoUrl, headers=infoHeaders)
    except Exception as e:
        sys.stderr.write("email address get request failed " + str(e) + "\n")
        return 4

    if infoResponse.status_code != 200:
        sys.stderr.write("email address request failed " + str(infoResponse) + "\n")
        sys.stderr.write(_error_details(infoResponse.text) + "\n")
        return 5

    # save user info (email, name); the second pair of keys maps the field
    # names of the other info endpoint onto the same token store entries
    jsonInfoResponse = json.loads(infoResponse.text)
    for key, value in jsonInfoResponse.items():
        if key == "email" or key == "name":
            tokens[key] = value
        elif key == "userPrincipalName":
            tokens["email"] = value
        elif key == "displayName":
            tokens["name"] = value

    save_tokens(token_store, tokens.items())
    return 0


def refresh(provider, client_id, client_secret, token_store):
    """Request a new access token using the stored refresh token.

    The received tokens are merged into the token store. If the provider
    reports the refresh token as expired or revoked, a full interactive
    authentication is attempted through generate().

    Parameters:
        provider: key into the module-level providers dict.
        client_id: OAuth2 client id.
        client_secret: OAuth2 client secret (only sent for "gmail-oauth2").
        token_store: path of the file holding the tokens.

    Returns:
        0 on success, 4 on http request failure, 5 on a non-200 or
        unparsable response, 7 if the token store holds no refresh token,
        or the return value of generate() after a reauthentication attempt.
    """
    # read token store
    tokens = read_tokens(token_store)

    # ensure refresh_token is present
    if "refresh_token" not in tokens:
        sys.stderr.write("refresh_token not set in " + token_store + "\n")
        return 7

    # use refresh code to request new access token
    refrParams = {
        "client_id": client_id,
        "refresh_token": tokens["refresh_token"],
        "grant_type": "refresh_token",
    }
    if provider == "gmail-oauth2":
        refrParams["client_secret"] = client_secret

    refrUrl = providers[provider]["refr_url"]
    refrHdr = {'Content-Type': 'application/x-www-form-urlencoded'}
    try:
        refrResponse = post(refrUrl, data=refrParams, headers=refrHdr)
    except Exception as e:
        sys.stderr.write("token refresh http post failed " + str(e) + "\n")
        return 4

    # parse response; error bodies are not guaranteed to be JSON, so fall
    # back to the raw text for reporting
    try:
        jsonRefrResponse = json.loads(refrResponse.text)
        errorDetails = str(jsonRefrResponse)
    except ValueError:
        jsonRefrResponse = None
        errorDetails = refrResponse.text
    isJsonObject = isinstance(jsonRefrResponse, dict)

    if refrResponse.status_code == 400 and isJsonObject:
        # either field may be absent from an error response
        errorDesc = str(jsonRefrResponse.get("error_description", ""))
        isInvalidGrant = (jsonRefrResponse.get("error") == "invalid_grant")
        isDescGmailExpired = (errorDesc == "Token has been expired or revoked.")
        isDescOutlookExpired = ("The user could not be authenticated as the grant is expired. The user must sign in again." in errorDesc)
        if isInvalidGrant and (isDescGmailExpired or isDescOutlookExpired):
            sys.stderr.write("token expired " + str(refrResponse) + "\n")
            sys.stderr.write(errorDetails + "\n")
            sys.stderr.write("attempt to reauthenticate\n")
            return generate(provider, client_id, client_secret, token_store)

    if refrResponse.status_code != 200 or not isJsonObject:
        sys.stderr.write("token refresh failed " + str(refrResponse) + "\n")
        sys.stderr.write(errorDetails + "\n")
        return 5

    # save received tokens on top of the stored ones
    tokens.update(jsonRefrResponse)
    save_tokens(token_store, tokens.items())
    return 0


def main(argv):
    """Parse the command line, read the environment and run the requested operation.

    Parameters:
        argv: argument list in sys.argv layout (program name followed by
            exactly one option).

    Always terminates through sys.exit(): 0 after --help / --version, 1 on
    usage errors or missing environment variables, otherwise the return
    value of generate() or refresh().
    """
    # Global setting: mask all group/other permission bits for created files
    os.umask(stat.S_IRWXG | stat.S_IRWXO)

    # Process arguments (from the argv parameter, not from sys.argv directly)
    if len(argv) != 2:
        show_help()
        sys.exit(1)

    option = argv[1]
    if option in ("--help", "-h"):
        show_help()
        sys.exit(0)
    if option in ("--version", "-v"):
        show_version()
        sys.exit(0)

    # Determine operation
    isGenerate = option in ("--generate", "-g")
    isRefresh = option in ("--refresh", "-r")
    if not isGenerate and not isRefresh:
        show_help()
        sys.exit(1)

    # Read required environment variables
    provider = os.getenv("OAUTH2_TYPE")
    if not provider:
        sys.stderr.write("env OAUTH2_TYPE not set\n")
        sys.exit(1)
    elif provider not in providers:
        sys.stderr.write("OAUTH2_TYPE provider " + provider + " not supported\n")
        sys.exit(1)

    client_id = os.getenv("OAUTH2_CLIENT_ID")
    if not client_id:
        sys.stderr.write("env OAUTH2_CLIENT_ID not set\n")
        sys.exit(1)

    client_secret = os.getenv("OAUTH2_CLIENT_SECRET")
    if not client_secret:
        sys.stderr.write("env OAUTH2_CLIENT_SECRET not set\n")
        sys.exit(1)

    token_store = os.getenv("OAUTH2_TOKEN_STORE")
    if not token_store:
        sys.stderr.write("env OAUTH2_TOKEN_STORE not set\n")
        sys.exit(1)

    # Perform requested operation
    if isGenerate:
        rv = generate(provider, client_id, client_secret, token_store)
    else:
        rv = refresh(provider, client_id, client_secret, token_store)
    sys.exit(rv)


# Script entry point: main() ends the process through sys.exit()
if __name__ == "__main__":
    main(sys.argv)
