# StarPU --- Runtime system for heterogeneous multicore architectures.
#
# Copyright (C) 2020-2023  Universit'e de Bordeaux, CNRS (LaBRI UMR 5800), Inria
#
# StarPU is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 2.1 of the License, or (at
# your option) any later version.
#
# StarPU is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
# See the GNU Lesser General Public License in COPYING.LGPL for more details.
#

import sys
import types
import joblib as jl
from joblib import logger
from joblib._parallel_backends import ParallelBackendBase
from starpu import starpupy
from starpu import Handle
import starpu
import asyncio
import math
import functools
try:
    import numpy as np
    has_numpy=True
except:
    has_numpy=False
import inspect
import threading

# Event loop on which Parallel's "normal" mode runs the task futures.
try:
	loop = asyncio.get_event_loop()
except RuntimeError:
	# newer Python versions no longer create a loop implicitly when none is set
	loop = asyncio.new_event_loop()
	asyncio.set_event_loop(loop)

# True when nest_asyncio was applied so run_until_complete() can be used inside
# an already-running loop. Always defined (default False) so that checks
# against it never hit an undefined name.
has_nest = False
if loop.is_running():
	try:
		import nest_asyncio
		nest_asyncio.apply()
		has_nest = True
	except ImportError:
		has_nest = False

# Registry of named backend factories used by parallel_backend(); empty by default.
BACKENDS={
	#'loky': LokyBackend,
}
# Per-thread storage of the active (backend, n_jobs) pair.
_backend = threading.local()

# get the number of CPUs controlled by StarPU
def cpu_count():
	"""Return the number of StarPU CPU workers (``STARPU_CPU_WORKER``)."""
	return starpupy.worker_get_count_by_type(starpu.STARPU_CPU_WORKER)

# split a list ls into at most n_block contiguous sub-lists of nearly equal length
def partition(ls, n_block):
	"""Cut the sequence *ls* into contiguous slices of nearly equal length.

	When ``len(ls) >= n_block`` exactly *n_block* slices are returned: the
	first ``len(ls) % n_block`` slices hold one element more than the others.
	Otherwise every element becomes its own one-element slice.
	"""
	total = len(ls)
	if total < n_block:
		# more blocks than elements: one slice per element
		return [ls[k:k+1] for k in range(total)]
	small, extra = divmod(total, n_block)
	# the first `extra` slices are one element longer
	big = small + 1 if extra else small
	boundary = extra * big
	head = [ls[k:k+big] for k in range(0, boundary, big)]
	tail = [ls[k:k+small] for k in range(boundary, total, small)]
	return head + tail

# split a two-dimensional numpy array into n_block sub-arrays
def array2d_split(a, n_block):
	"""Split the 2-D numpy array *a* into *n_block* C-contiguous sub-arrays.

	*n_block* is factored as ``c * r``, where ``c`` is the largest divisor of
	*n_block* not exceeding ``sqrt(n_block)``. *a* is cut into ``c`` bands
	along axis 0, and each band into ``r`` pieces along axis 1. Pieces are
	returned band by band, each copied into a C-ordered array. Uneven splits
	follow ``np.array_split``.

	Raises:
		ValueError: if *n_block* is smaller than 1.
	"""
	if n_block < 1:
		raise ValueError('n_block must be at least 1, got %r' % (n_block,))
	# factor n_block = c * r with c <= r as close as possible; the loop
	# always stops, at the latest at i == 1
	for i in range(math.floor(math.sqrt(n_block)), 0, -1):
		if n_block % i == 0:
			c = i
			r = n_block // i
			break
	arr_split = []
	# cut into c bands along axis 0 (rows) ...
	for band in np.array_split(a, c, 0):
		# ... then each band into r pieces along axis 1 (columns)
		for piece in np.array_split(band, r, 1):
			arr_split.append(piece.copy(order='C'))
	return arr_split


def _task_submitter(dict_task, sizebase):
	"""Return the ``starpu.task_submit`` decorator for one block.

	*dict_task* carries the task options collected by ``Parallel``; *sizebase*
	is the size of the block, forwarded to ``starpu.task_submit``.
	"""
	return starpu.task_submit(name=dict_task['name'], synchronous=dict_task['synchronous'],
		priority=dict_task['priority'], color=dict_task['color'], flops=dict_task['flops'],
		perfmodel=dict_task['perfmodel'], sizebase=sizebase,
		ret_handle=dict_task['ret_handle'], ret_fut=dict_task['ret_fut'],
		arg_handle=dict_task['arg_handle'], modes=dict_task['modes'])

def future_generator(iterable, n_jobs, dict_task):
	"""Submit the work described by *iterable* as StarPU tasks and return their futures.

	*iterable* is either a single ``(function, args, kwargs)`` tuple as built
	by ``delayed(f)(...)``, or an iterable of such tuples.

	Tuple form: numpy array arguments are registered with StarPU (or their
	existing handle in ``starpu.handle_dict`` is reused) and partitioned into
	``n_block`` pieces. Generator arguments are materialised and split with
	``partition``. Every other argument is passed unchanged. One task is
	submitted per block, calling ``function`` on the i-th pieces.

	Iterable form: the list of delayed calls is split with ``partition`` and
	one task per group calls every function of its group in turn.

	Parameters:
		iterable: the delayed call(s) to execute, see above.
		n_jobs: requested number of blocks. Negative values count back from
			``cpu_count()`` (``-1`` means all StarPU CPU workers). The result
			is clamped to at least 1.
		dict_task: options forwarded to ``starpu.task_submit`` (name,
			synchronous, priority, color, flops, perfmodel, ret_handle,
			ret_fut, arg_handle, modes).

	Returns:
		The list of objects returned by the submitted tasks, one per block.

	Raises:
		SystemExit: if ``n_jobs < -cpu_count() - 1``, or if the array /
			generator arguments do not have matching sizes.
	"""
	# number of blocks: negative n_jobs counts back from the number of StarPU
	# CPU workers, and at least one block is always used
	n_cpus=cpu_count()
	if n_jobs<-n_cpus-1:
		raise SystemExit('Error: n_jobs is out of range, number of CPUs is {}'.format(n_cpus))
	elif n_jobs<0:
		n_block=n_cpus+1+n_jobs
	else:
		n_block=n_jobs
	if n_block <= 0:
		n_block = 1

	# a single delayed call: split its array and generator arguments into
	# n_block pieces and submit one task per piece
	if type(iterable) is tuple:
		f=iterable[0]
		formal_args=inspect.getfullargspec(f).args
		# positional arguments first, then the keyword arguments that name a
		# formal parameter of f, in the order of those parameters
		# NOTE(review): keyword values are appended positionally, so they only
		# reach the right parameter when every earlier parameter was passed
		# positionally; keywords not naming a formal parameter are dropped
		args=list(iterable[1])
		for name in formal_args:
			if name in iterable[2]:
				args.append(iterable[2][name])
		# total lengths of the generator arguments, which must all agree
		l_arr=[]
		# objects returned by task submission, one per block
		L_fut=[]
		# args_split[i]: the pieces of args[i] ([] when args[i] is not split)
		args_split=[]
		# arg_h[i]: StarPU handle of args[i] when it is a numpy array, else None
		arg_h=[]
		for i in range(len(args)):
			args_split.append([])
			if has_numpy and type(args[i]) is np.ndarray:
				# reuse the handle if the array is already registered
				arr_h = starpu.handle_dict.get(id(args[i]))
				if arr_h is None:
					arr_h = Handle(args[i])
				arg_h.append(arr_h)
				# NOTE(review): second argument 0 is presumably the split dimension — confirm
				args_split[i] = arr_h.partition(n_block, 0)
			elif isinstance(args[i],types.GeneratorType):
				# materialise the generator and split it like a list
				args_split[i]=partition(list(args[i]),n_block)
				arg_h.append(None)
				l_arr.append(sum(len(piece) for piece in args_split[i]))
			else:
				arg_h.append(None)
		if len(set(l_arr))>1:
			raise SystemExit('Error: all arrays should have the same size')
		for i in range(n_block):
			# arguments of the i-th task: the i-th piece of every split argument,
			# every other argument unchanged; all pieces must have the same size
			# NOTE(review): a generator shorter than n_block yields fewer pieces,
			# so args_split[j][i] raises IndexError for the missing blocks
			L_args=[]
			sizebase=0
			for j in range(len(args)):
				if (has_numpy and type(args[j]) is np.ndarray):
					L_args.append(args_split[j][i])
					n_arr = arg_h[j].get_partition_size(args_split[j])
					if sizebase==0:
						sizebase=n_arr[i]
					elif sizebase!=n_arr[i]:
						raise SystemExit('Error: all arrays should be split into equal size')
				elif isinstance(args[j],types.GeneratorType):
					L_args.append(args_split[j][i])
					if sizebase==0:
						sizebase=len(args_split[j][i])
					elif sizebase!=len(args_split[j][i]):
						raise SystemExit('Error: all arrays should be split into equal size')
				else:
					L_args.append(args[j])
			L_fut.append(_task_submitter(dict_task, sizebase)(f, *L_args))
		# unpartition and unregister the numpy arrays
		# NOTE(review): this also unregisters handles that were found already
		# registered in starpu.handle_dict before this call — confirm intended
		for i in range(len(args)):
			if (has_numpy and type(args[i]) is np.ndarray):
				arg_h[i].unpartition(args_split[i], n_block)
				arg_h[i].unregister()
		return L_fut

	# an iterable of delayed calls: split the calls into n_block groups and
	# run each group sequentially inside one task
	def lf(ls):
		"""Call every ``(function, args[, kwargs])`` entry of *ls*; return the results in order."""
		L_func=[]
		for call in ls:
			# forward the keyword arguments captured by delayed(), if any
			kwargs=(call[2] if len(call)>2 else None) or {}
			L_func.append(call[0](*call[1], **kwargs))
		return L_func

	L_fut=[]
	for group in partition(list(iterable),n_block):
		L_fut.append(_task_submitter(dict_task, len(group))(lf, group))
	return L_fut

class Parallel(object):
	"""joblib-style ``Parallel`` that runs delayed calls as StarPU tasks.

	``mode`` selects how results are delivered. ``"normal"`` blocks on the
	module event loop until every task has finished and returns the merged
	results. ``"future"`` returns an ``asyncio.gather`` future the caller must
	await. In ``"future"`` mode a non-None ``end_msg`` is passed to ``print``,
	together with the finished future, when that future completes.

	``perfmodel``, ``name``, ``synchronous``, ``priority``, ``color``,
	``flops``, ``ret_handle``, ``ret_fut``, ``arg_handle`` and ``modes`` are
	collected into the task options handed to ``future_generator``.

	Of the joblib parameters, only ``n_jobs`` (``None`` means 1) and
	``backend`` are used; ``backend`` is used only for its optional
	``start_call``/``stop_call`` hooks. The remaining joblib parameters are
	accepted for signature compatibility and ignored.
	"""
	def __init__(self, mode="normal", perfmodel=None, end_msg=None,
			name=None, synchronous=0, priority=0, color=None, flops=None,
			ret_handle=False, ret_fut=True, arg_handle=True, modes=None,
			n_jobs=None, backend=None, verbose=0, timeout=None, pre_dispatch='2 * n_jobs',
			batch_size='auto', temp_folder=None, max_nbytes='1M',
			mmap_mode='r', prefer=None, require=None):
		# NOTE: backend selection through get_active_backend()/BACKENDS is not
		# wired in; *backend* is stored exactly as given
		if n_jobs is None:
			n_jobs = 1

		self.mode=mode
		self.perfmodel=perfmodel
		self.end_msg=end_msg
		self.name=name
		self.synchronous=synchronous
		self.priority=priority
		self.color=color
		self.flops=flops
		self.ret_handle=ret_handle
		self.ret_fut=ret_fut
		self.arg_handle=arg_handle
		self.modes=modes
		self.n_jobs=n_jobs
		self._backend=backend

	def print_progress(self):
		"""Print the value returned by ``starpupy.task_nsubmitted()``."""
		# TODO: real progress reporting
		print("", starpupy.task_nsubmitted())

	@staticmethod
	def _merge_results(results):
		"""Concatenate the per-block results in order.

		Each entry of *results* is what one task returned and is concatenated
		with ``list.extend``. If any entry is ``None``, the merged result is
		``None``.
		"""
		if any(res is None for res in results):
			return None
		merged=[]
		for res in results:
			merged.extend(res)
		return merged

	def __call__(self,iterable):
		"""Submit *iterable* (delayed call(s)) as StarPU tasks.

		Returns:
			In ``"normal"`` mode, the merged results (see ``_merge_results``).
			In ``"future"`` mode, an awaitable ``asyncio`` future.

		Raises:
			ValueError: if ``self.mode`` is neither ``"normal"`` nor ``"future"``.
		"""
		if self.mode not in ("normal", "future"):
			raise ValueError("Invalid mode: %r, expected 'normal' or 'future'" % (self.mode,))
		# options forwarded to starpu.task_submit
		dict_task={'name': self.name, 'synchronous': self.synchronous, 'priority': self.priority, 'color': self.color, 'flops': self.flops, 'perfmodel': self.perfmodel, 'ret_handle': self.ret_handle, 'ret_fut': self.ret_fut, 'arg_handle': self.arg_handle, 'modes': self.modes}
		if hasattr(self._backend, 'start_call'):
			self._backend.start_call()
		try:
			# normal mode: the user calls Parallel directly, without asyncio
			if self.mode=="normal":
				async def asy_main():
					L_fut=future_generator(iterable, self.n_jobs, dict_task)
					# await the blocks in submission order
					L_res=[await fut for fut in L_fut]
					return self._merge_results(L_res)
				# run_until_complete() inside a running loop needs nest_asyncio
				if loop.is_running() and not has_nest:
					raise starpupy.error("Can't find \'nest_asyncio\' module (consider running \"pip3 install nest_asyncio\" or try to remove \"-m asyncio\" when starting Python interpreter)")
				return loop.run_until_complete(asy_main())
			# future mode: the caller awaits the returned future itself
			L_fut=future_generator(iterable, self.n_jobs, dict_task)
			fut=asyncio.gather(*L_fut)
			if self.end_msg is not None:
				fut.add_done_callback(functools.partial(print, self.end_msg))
			return fut
		finally:
			# always give the backend its stop hook, even on error
			if hasattr(self._backend, 'stop_call'):
				self._backend.stop_call()

def delayed(function):
	"""Capture a call to *function* instead of performing it.

	``delayed(f)(*args, **kwargs)`` returns the tuple ``(f, args, kwargs)``,
	the format consumed by ``Parallel``.
	"""
	def delayed_function(*args, **kwargs):
		# package the call for later execution
		call = (function, args, kwargs)
		return call
	return delayed_function


######################################################################
# re-export the version string of the joblib package this module wraps
__version__ = jl.__version__

class Memory(jl.Memory):
	"""``joblib.Memory`` subclass that forwards its constructor arguments to joblib.

	``cachedir`` and ``bytes_limit`` are only forwarded when they are not
	``None``. Newer joblib releases deprecated these parameters (``cachedir``
	was later removed), so always passing them can break construction.
	"""
	def __init__(self,location=None, backend='local', cachedir=None,
			mmap_mode=None, compress=False, verbose=1, bytes_limit=None,
			backend_options=None):
		kwargs = dict(location=location, backend=backend, mmap_mode=mmap_mode,
			compress=compress, verbose=verbose, backend_options=backend_options)
		if cachedir is not None:
			kwargs['cachedir'] = cachedir
		if bytes_limit is not None:
			kwargs['bytes_limit'] = bytes_limit
		super(Memory, self).__init__(**kwargs)


def dump(value, filename, compress=0, protocol=None, cache_size=None):
	"""Persist *value* into *filename*; thin wrapper around ``joblib.dump``."""
	return jl.dump(value, filename, compress=compress, protocol=protocol,
		cache_size=cache_size)

def load(filename, mmap_mode=None):
	"""Load an object stored in *filename*; thin wrapper around ``joblib.load``."""
	return jl.load(filename, mmap_mode=mmap_mode)

def hash(obj, hash_name='md5', coerce_mmap=False):
	"""Return the hash of *obj* computed by ``joblib.hash``."""
	return jl.hash(obj, hash_name=hash_name, coerce_mmap=coerce_mmap)

def register_compressor(compressor_name, compressor, force=False):
	"""Register *compressor* under *compressor_name* via ``joblib.register_compressor``."""
	return jl.register_compressor(compressor_name, compressor, force=force)

def effective_n_jobs(n_jobs=-1):
	"""Return the number of blocks that would actually be used for *n_jobs*.

	Follows the joblib convention used by ``future_generator``:

	- negative values count back from the number of StarPU CPU workers
	  (``-1`` means all of them, ``-2`` all but one, ...);
	- ``None`` means 1;
	- the result is never smaller than 1.

	The default ``-1`` therefore returns ``cpu_count()``.
	"""
	if n_jobs is None:
		return 1
	if n_jobs < 0:
		n_jobs = cpu_count() + 1 + n_jobs
	return max(n_jobs, 1)

def get_active_backend():
	"""Return the backend activated by ``parallel_backend`` in the current thread.

	Without an active context, return a new instance (``nesting_level=0``) of
	the factory registered in BACKENDS under ``'loky'``. Return ``None`` if no
	such factory is registered; the original lookup used the undefined name
	``loky`` and always failed.
	"""
	backend_and_jobs = getattr(_backend, 'backend_and_jobs', None)
	if backend_and_jobs is not None:
		backend, _n_jobs = backend_and_jobs
		return backend
	factory = BACKENDS.get('loky')
	if factory is None:
		# no default backend registered (the 'loky' entry of BACKENDS is disabled)
		return None
	return factory(nesting_level=0)

class parallel_backend(object):
	"""Context manager that makes *backend* the active backend of the current thread.

	*backend* is either a backend object or the name of a factory registered
	in BACKENDS, which is then called with ``**backend_params``. The previously
	active ``(backend, n_jobs)`` pair is restored by ``unregister`` (also
	called on exit). ``inner_max_num_threads`` is accepted for joblib
	compatibility but not used.
	"""
	def __init__(self, backend, n_jobs=-1, inner_max_num_threads=None,
			**backend_params):
		if isinstance(backend, str):
			factory = BACKENDS[backend]
			backend = factory(**backend_params)

		previous = getattr(_backend, 'backend_and_jobs', None)
		if backend.nesting_level is None:
			# inherit the enclosing backend's nesting level, or start at 0
			backend.nesting_level = 0 if previous is None else previous[0].nesting_level

		# remember what to restore, then activate the new backend
		self.old_backend_and_jobs = previous
		self.new_backend_and_jobs = (backend, n_jobs)
		_backend.backend_and_jobs = (backend, n_jobs)

	def __enter__(self):
		"""Return the ``(backend, n_jobs)`` pair that was activated."""
		return self.new_backend_and_jobs

	def __exit__(self, type, value, traceback):
		self.unregister()

	def unregister(self):
		"""Restore the backend that was active before this context was created."""
		previous = self.old_backend_and_jobs
		if previous is not None:
			_backend.backend_and_jobs = previous
		elif getattr(_backend, 'backend_and_jobs', None) is not None:
			del _backend.backend_and_jobs

def register_parallel_backend(name, factory):
	"""Register *factory* in BACKENDS under *name*, replacing any previous entry."""
	BACKENDS.update({name: factory})
