#===============================================================================
# Copyright 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

from onedal import _backend, _is_dpc_backend
import sys

# daal4py's oneAPI interop is only engaged when ``daal4py.oneapi`` has already
# been imported by someone else; this module never imports it on its own. The
# flag is a one-time snapshot taken when this module is first imported.
if 'daal4py.oneapi' in sys.modules:
    oneapi_is_available = True
    from daal4py.oneapi import _get_sycl_ctxt, sycl_execution_context
else:
    oneapi_is_available = False


def _get_policy(queue, *data):
    """Select the oneDAL execution policy for a computation.

    Parameters
    ----------
    queue : object or None
        Explicitly requested SYCL queue. When given, it takes precedence over
        any queue attached to the input data.
    *data
        Input arrays; the queue (if any) is taken from the first one via
        ``_get_queue``.

    Returns
    -------
    _HostInteropPolicy or _DataParallelInteropPolicy
        On a DPC++-enabled backend: a data-parallel policy bound to ``queue``,
        or else to the first input's queue; a host policy when neither exists.
        On a host-only backend: always a host policy.

    Raises
    ------
    RuntimeError
        If the backend has no DPC++ support but a queue was requested or the
        input data carries a SYCL queue.
    """
    data_queue = _get_queue(*data)
    if _is_dpc_backend:
        # An explicit queue wins over the one discovered on the data.
        target_queue = queue if queue is not None else data_queue
        if target_queue is None:
            return _HostInteropPolicy()
        return _DataParallelInteropPolicy(target_queue)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently drop the requested device and fall
    # back to host execution.
    if queue is not None or data_queue is not None:
        raise RuntimeError(
            "A SYCL queue or USM input data was provided, but the loaded "
            "oneDAL backend was built without DPC++ support"
        )
    return _HostInteropPolicy()


def _get_queue(*data):
    """Return the SYCL queue attached to the first input, if any.

    Only ``data[0]`` is inspected, on the assumption that every input lives on
    the same device. The result is the ``'syclobj'`` entry of that input's
    ``__sycl_usm_array_interface__``. Returns ``None`` when no inputs are given
    or when the first input does not expose that interface.
    """
    if not data:
        return None
    first = data[0]
    if not hasattr(first, '__sycl_usm_array_interface__'):
        return None
    usm_interface = first.__sycl_usm_array_interface__
    return usm_interface['syclobj']


class _Daal4PyContextReset:
    """Switch daal4py's oneAPI execution context to CPU for this object's lifetime.

    When ``daal4py.oneapi`` was loaded (``oneapi_is_available``), construction
    records the current daal4py SYCL context and applies a fresh ``'cpu'``
    execution context. Finalization re-applies the recorded context, if one
    was recorded and it is truthy. Without daal4py.oneapi, this does nothing.

    NOTE(review): the restore happens in ``__del__``, so it depends on when
    the interpreter finalizes the instance (prompt under CPython refcounting,
    not guaranteed elsewhere). If two instances overlap and the older one is
    finalized first, the newer one presumably re-applies the CPU context it
    captured. Confirm whether overlapping policies can occur.
    """

    def __init__(self):
        # Both remain None unless daal4py.oneapi is loaded.
        self._d4p_context = None
        self._host_context = None
        if not oneapi_is_available:
            return
        self._d4p_context = _get_sycl_ctxt()
        cpu_context = sycl_execution_context('cpu')
        # Keep a reference so the applied CPU context lives as long as we do.
        self._host_context = cpu_context
        cpu_context.apply()

    def __del__(self):
        saved_context = self._d4p_context
        if saved_context:
            saved_context.apply()


class _HostInteropPolicy(_backend.host_policy):
    """oneDAL host (CPU) policy with daal4py context interop.

    After initializing the backend host policy, each instance holds a
    ``_Daal4PyContextReset``. If ``daal4py.oneapi`` was loaded, daal4py
    therefore runs on a CPU execution context until that reset object (held
    by this policy) is finalized.
    """

    def __init__(self):
        super(_HostInteropPolicy, self).__init__()
        self._d4p_interop = _Daal4PyContextReset()


if _is_dpc_backend:
    class _DataParallelInteropPolicy(_backend.data_parallel_policy):
        """oneDAL data-parallel policy bound to a SYCL queue.

        Defined only when the backend reports DPC++ support. The queue is kept
        on ``self._queue``. Like the host policy, each instance holds a
        ``_Daal4PyContextReset`` for its lifetime, so daal4py (if its oneAPI
        module was loaded) is switched to a CPU context while the policy exists.
        """

        def __init__(self, queue):
            self._queue = queue
            self._d4p_interop = _Daal4PyContextReset()
            # sklearnex's DummySyclQueue is only recognized when sklearnex is
            # already imported; this module never imports sklearnex itself.
            is_dummy_queue = False
            if 'sklearnex' in sys.modules:
                from sklearnex._device_offload import DummySyclQueue
                is_dummy_queue = isinstance(self._queue, DummySyclQueue)
            if is_dummy_queue:
                # Presumably a stand-in with no native queue handle (TODO:
                # confirm). The backend is built from the device filter string.
                backend_arg = self._queue.sycl_device.get_filter_string()
            else:
                # Assumes a dpctl.SyclQueue-like object exposing addressof_ref().
                backend_arg = self._queue.addressof_ref()
            super().__init__(backend_arg)
