#!/usr/bin/env python

import xmlrpc.client
import http
import time
import functools
import argparse
import ssl
import os, sys
import socket

# decorator to drive the login if required
def retry_with_login(host, login_fn, auth_check_fn, max_attempts=3, delay=1, logger=None):
    """Build a decorator that logs in on demand around a remote call.

    The decorated function is only invoked once ``auth_check_fn()`` reports an
    authenticated session. If its reply is an authentication error, one
    re-login is attempted and the call is repeated once.

    Args:
        host: Host name, used only in the exception messages.
        login_fn: Zero-argument callable performing the login; a truthy
            return value means the login succeeded.
        auth_check_fn: Zero-argument callable returning a truthy value when a
            session is currently held.
        max_attempts: Maximum number of loop iterations.
        delay: Seconds to sleep between iterations; a falsy value disables
            the sleep.
        logger: Accepted for interface compatibility; not used here.

    Returns:
        A decorator. The wrapped function returns the last reply obtained
        from the decorated function.

    Raises:
        Exception: "Login failed to <host>" when the re-login after an
            authentication error fails on the final attempt, or
            "Failed to retrieve response from <host>" when the decorated
            function was never called.
    """
    auth_error_types = ('NotAuthenticated', 'AuthenticationError')
    # Sentinel so that a legitimate None reply is not mistaken for "no call was made".
    not_called = object()

    def is_auth_error(response):
        # Only dict replies can carry the status/type error envelope; lists,
        # strings, numbers and None are ordinary results and must not be
        # probed with .get().
        return (isinstance(response, dict)
                and response.get('status') == 'error'
                and response.get('type') in auth_error_types)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            response = not_called
            while attempts < max_attempts:
                attempts += 1
                if not auth_check_fn():
                    # No session yet: try to obtain one before calling.
                    login_fn()
                if auth_check_fn():
                    # Session held: try the wrapped function.
                    response = func(*args, **kwargs)
                    if not is_auth_error(response):
                        break
                    # The server rejected the session: log in again.
                    if login_fn():
                        # Have another go at calling the function after login.
                        response = func(*args, **kwargs)
                        break
                    if attempts >= max_attempts:
                        raise Exception("Login failed to %s" %(host))

                if attempts < max_attempts and delay:
                    time.sleep(delay)
            if response is not_called:
                raise Exception("Failed to retrieve response from %s" %(host))
            return response
        return wrapper
    return decorator

class SafeCookieTimeoutTransport(xmlrpc.client.SafeTransport):
    """HTTPS XML-RPC transport that adds a connection timeout and a session header.

    The value of the ``Session-Cookie`` response header is remembered after
    every parsed response and sent back as a request header on later calls.
    The stored value is discarded when the reply is an authentication error.
    """
    COOKIE_NAME = 'Session-Cookie'
    # Values of the reply's 'type' field that invalidate the stored session.
    AUTH_ERROR_TYPES = ('NotAuthenticated', 'AuthenticationError')

    def __init__(self, timeout=None, **kwargs): # New
        """Create the transport.

        Args:
            timeout: Timeout passed to the HTTPS connection; a falsy value
                selects the socket module's global default.
            **kwargs: Forwarded unchanged to ``SafeTransport.__init__``.
        """
        super().__init__(**kwargs)
        self.timeout = timeout or http.client.socket._GLOBAL_DEFAULT_TIMEOUT
        self.verbose = True
        self.response_headers = None
        # Initialised here so the attribute exists before the first response.
        self.status_code = None
        self.session_cookie = None

    def set_timeout(self, timeout):
        """Change the timeout used for subsequent connections."""
        self.timeout = timeout
        # make_connection() hands back the cached connection for the same
        # host, and that connection was created with the previous timeout.
        # Drop it so the next request reconnects with the new value.
        self.close()

    def make_connection(self, host): # Original function from Python-3.10.5
        """Return an HTTPS connection to ``host``, reusing the cached one if possible."""
        if self._connection and host == self._connection[0]:
            return self._connection[1]
        if not hasattr(http.client, "HTTPSConnection"):
            raise NotImplementedError(
            "your version of http.client doesn't support HTTPS")
        # create a HTTPS connection object from a host descriptor
        # host may be a string, or a (host, x509-dict) tuple
        chost, self._extra_headers, x509 = self.get_host_info(host)
        self._connection = host, http.client.HTTPSConnection(chost, None,
                                                             timeout=self.timeout, # This line added
                                                             context=self.context, **(x509 or {}))
        return self._connection[1]

    def send_headers(self, connection, headers):
        """Send the stored session value (if any) ahead of the regular headers."""
        if self.session_cookie:
            connection.putheader(self.COOKIE_NAME, self.session_cookie)
        super().send_headers(connection, headers)

    def parse_response(self, response):
        """Parse the reply while recording its headers, status and session value.

        Returns:
            The tuple produced by ``SafeTransport.parse_response``.
        """
        self.response_headers = response.headers
        self.status_code = response.status
        self.session_cookie = self.response_headers.get(self.COOKIE_NAME)
        parsed_response = super().parse_response(response)
        if self.status_code == 200 and parsed_response:
            result = parsed_response[0]
            # Replies are not always structs (a method may return a list or a
            # scalar), so only probe dicts for the error envelope.
            if (isinstance(result, dict)
                    and result.get('status') == 'error'
                    and result.get('type') in self.AUTH_ERROR_TYPES):
                self.session_cookie = None
        return parsed_response

    def is_authenticated(self):
        """Return True while a session value is stored."""
        return bool(self.session_cookie)

class SystemProxy(object):
    """XML-RPC client for ``https://<host>/RPC2/`` that logs in on demand.

    Attribute access for names not defined on this class is forwarded to the
    underlying ``xmlrpc.client.ServerProxy``; callable results are wrapped so
    that a login is performed first when no session is held, and once more
    when the server replies with an authentication error.
    """

    def __init__(self, host, password, timeout=None, context=None, logger=None):
        """Store the connection settings and build the transport and proxy.

        Args:
            host: Host name placed in the endpoint URL.
            password: Password sent with the 'admin' login.
            timeout: Timeout handed to the transport.
            context: SSL context handed to the transport.
            logger: Stored and passed to the retry decorator.
        """
        self.host = host
        self.password = password
        self.timeout = timeout
        self.context = context
        self.logger = logger
        self.transport = SafeCookieTimeoutTransport(self.timeout, context=self.context)
        self._build_proxy()

    def _set_password(self, password):
        """Replace the stored password when it differs from the current one."""
        if password != self.password:
            self.password = password

    def _is_authenticated(self):
        """Return whether the transport currently holds a session."""
        return self.transport.is_authenticated()

    def handle_exception(self, e):
        """Re-raise ``e`` as a plain Exception with a user-oriented message.

        Raises:
            Exception: Always; the original exception is attached as the cause.
        """
        if isinstance(e, socket.gaierror):
            errorMsg = f"{self.host}: Name or service not known"
        elif isinstance(e, socket.error):
            errorMsg = f"Unable to communicate with {self.host}"
        else:
            errorMsg = str(e)
        # Keep the original exception as the cause so tracebacks stay useful.
        raise Exception(errorMsg) from e

    def _login(self):
        """Call the remote ``Login`` method as 'admin'.

        Returns:
            True when the transport holds a session afterwards.

        Raises:
            Exception: On communication failure or when the reply is an
                authentication error.
        """
        try:
            response = self.proxy.Login('admin', self.password)
        except Exception as e:
            self.handle_exception(e)
        # Guard against non-struct replies before probing for the error type.
        if isinstance(response, dict) and response.get('type') in ['NotAuthenticated', 'AuthenticationError']:
            raise Exception("Login failure to %s" %(self.host))
        return self._is_authenticated()

    def __getattr__(self, name):
        """Forward unknown attributes to the XML-RPC proxy, wrapping callables."""
        # __getattr__ only runs when normal lookup fails. If 'proxy' itself is
        # missing (bare instance, copy, unpickle), reading self.proxy here
        # would re-enter __getattr__ without end, so read the instance dict.
        try:
            proxy = self.__dict__['proxy']
        except KeyError:
            raise AttributeError(name) from None
        attr = getattr(proxy, name)
        if callable(attr):
            @retry_with_login(self.host, self._login, self._is_authenticated, max_attempts=1, delay=1, logger=self.logger)  # Apply retry decorator
            def wrapper(*args, **kwargs):
                try:
                    return attr(*args, **kwargs)
                except Exception as e:
                    self.handle_exception(e)
            return wrapper
        else:
            return attr

    def _build_proxy(self):
        """(Re)create the endpoint URL and the ServerProxy for the current host."""
        self.url = 'https://%s/RPC2/' %(self.host)
        self.proxy = xmlrpc.client.ServerProxy(self.url, transport=self.transport, encoding='utf-8', allow_none=False)

    def _set_timeout(self, timeout):
        """Update the timeout here and on the transport."""
        self.timeout = timeout
        self.transport.set_timeout(timeout)

    def _set_host(self, host):
        """Point the client at a different host and rebuild the proxy."""
        if host != self.host:
            self.host = host
            # The stored session belongs to the previous host: do not send it
            # to the new one, and release the old connection.
            self.transport.session_cookie = None
            self.transport.close()
            self._build_proxy()


def parse_defaults_file(filepath):
    """Read ``key=value`` defaults from a text file.

    Everything from the first ``#`` on a line is discarded, blank lines and
    lines without ``=`` are ignored, and dashes in keys become underscores.
    Any error while reading is reported on stderr and the process exits with
    status 1.

    Args:
        filepath: Path of the defaults file.

    Returns:
        Dict mapping normalised keys to their stripped string values.
    """
    parsed = {}
    try:
        with open(filepath, 'r') as handle:
            for raw in handle:
                # Keep only the text before the first '#'.
                content = raw.partition('#')[0].strip()
                if not content or '=' not in content:
                    continue
                # Split on the first '=' only, so values may contain '='.
                name, _, setting = content.partition('=')
                parsed[name.strip().replace('-', '_')] = setting.strip()
    except Exception as e:
        sys.stderr.write(f"Error reading defaults file: {e}\n")
        sys.exit(1)
    return parsed

def parse_cmdline():
    """Parse the command line, seeding option defaults from ``--defaults``.

    ``--defaults`` is extracted first with a helper parser; the key/value
    pairs read from that file are installed as parser defaults, so options
    given explicitly on the command line take precedence.

    Returns:
        The ``argparse.Namespace`` produced by the main parser.
    """
    # First pass: pick out --defaults only and leave everything else untouched.
    bootstrap = argparse.ArgumentParser(add_help=False)
    bootstrap.add_argument('--defaults', help='Path to defaults file (key=value per line)')
    early, leftover = bootstrap.parse_known_args()

    parser = argparse.ArgumentParser(parents=[bootstrap], description="SvHCI VM import from vSphere ESXi/VCSA")

    # (flag, destination attribute, help text), in the order they are registered.
    options = (
        ('--hostname', 'hostname', 'SvHCI hostname'),
        ('--password', 'password', 'Password for authentication'),
        ('--vsphere-hostname', 'vsphere_hostname', 'vSphere ESXi/VCSA address'),
        ('--vsphere-username', 'vsphere_username', 'Username for vSphere authentication'),
        ('--vsphere-password', 'vsphere_password', 'Password for vSphere authentication'),
        ('--vm-name', 'vm_name', 'Name of the virtual machine'),
        ('--pool', 'pool', 'Name of the SvHCI storage pool'),
        ('--remote-pool', 'remote_pool', 'Name of the remote SvHCI storage pool (adds VM protection)'),
    )
    for flag, dest, help_text in options:
        parser.add_argument(flag, dest=dest, help=help_text)

    file_defaults = parse_defaults_file(early.defaults) if early.defaults else {}
    parser.set_defaults(**file_defaults)

    # Second pass: parse what the first pass did not consume.
    return parser.parse_args(leftover)

def validate_args(args):
    """Exit with status 1 if any mandatory option is still None.

    All missing options are listed on stderr in one message, each prefixed
    with ``--``.

    Args:
        args: Parsed argument namespace.
    """
    # (namespace attribute, option name shown to the user), in reporting order.
    mandatory = (
        ('hostname', 'hostname'),
        ('password', 'password'),
        ('vsphere_hostname', 'vsphere-hostname'),
        ('vsphere_username', 'vsphere-username'),
        ('vsphere_password', 'vsphere-password'),
        ('pool', 'pool'),
        ('vm_name', 'vm-name'),
    )
    missing = ["--" + option for attr, option in mandatory if getattr(args, attr) is None]
    if not missing:
        return

    sys.stderr.write("The following required arguments were not provided: %s\n" %(', '.join(missing)))
    sys.exit(1)


def main():
    """Start a VM import task on the SvHCI host named on the command line.

    Parses and validates the arguments, builds the import specification and
    calls the remote ``ImportVMTask`` method. If the reply has status
    'error', the fault message is written to stderr and the process exits
    with status 1.

    Returns:
        The reply returned by ``ImportVMTask``.
    """
    args = parse_cmdline()
    validate_args(args)

    spec = {
        'hostname': args.vsphere_hostname,
        'username': args.vsphere_username,
        'password': args.vsphere_password,
        'vmName': args.vm_name,
        'pool': args.pool,
        'validate': True,
        'mapping': {},
    }
    # The remote pool is optional and only sent when one was supplied.
    if args.remote_pool is not None:
        spec['remotePool'] = args.remote_pool

    # NOTE(review): the unverified context disables TLS certificate checks.
    tls_context = ssl._create_unverified_context()
    client = SystemProxy(args.hostname, args.password, context=tls_context)
    target = {'uuid': 0, 'type': 'Compute'}

    response = client.ImportVMTask(target, spec)
    if response.get('status') == 'error':
        detail = response.get('error', {}).get('faultMessage', 'No error supplied')
        sys.stderr.write("Error: Failed to start VM import task\n%s\n" %(detail))
        sys.exit(1)

    print("Import task started on %s for VM '%s', from %s" %(args.hostname, args.vm_name, args.vsphere_hostname))
    return response


if __name__ == '__main__':
    # Script entry point: report any failure as a single stderr line and
    # exit with status 1 instead of showing a traceback.
    try:
        main()
    except Exception as exc:
        message = str(exc)
        sys.stderr.write(f"Error: {message}\n")
        sys.exit(1)

