diff --git a/openml/_api_calls.py b/openml/_api_calls.py index 888afa18e..c357dc3d0 100644 --- a/openml/_api_calls.py +++ b/openml/_api_calls.py @@ -1,15 +1,15 @@ # License: BSD 3-Clause import time +import hashlib import logging import requests -import warnings import xmltodict -from typing import Dict +from typing import Dict, Optional from . import config from .exceptions import (OpenMLServerError, OpenMLServerException, - OpenMLServerNoResult) + OpenMLServerNoResult, OpenMLHashException) def _perform_api_call(call, request_method, data=None, file_elements=None): @@ -47,20 +47,105 @@ def _perform_api_call(call, request_method, data=None, file_elements=None): url = url.replace('=', '%3d') logging.info('Starting [%s] request for the URL %s', request_method, url) start = time.time() + if file_elements is not None: if request_method != 'post': - raise ValueError('request method must be post when file elements ' - 'are present') - response = _read_url_files(url, data=data, file_elements=file_elements) + raise ValueError('request method must be post when file elements are present') + response = __read_url_files(url, data=data, file_elements=file_elements) else: - response = _read_url(url, request_method, data) + response = __read_url(url, request_method, data) + + __check_response(response, url, file_elements) + logging.info( '%.7fs taken for [%s] request for the URL %s', time.time() - start, request_method, url, ) - return response + return response.text + + +def _download_text_file(source: str, + output_path: Optional[str] = None, + md5_checksum: str = None, + exists_ok: bool = True, + encoding: str = 'utf8', + ) -> Optional[str]: + """ Download the text file at `source` and store it in `output_path`. + + By default, do nothing if a file already exists in `output_path`. + The downloaded file can be checked against an expected md5 checksum. + + Parameters + ---------- + source : str + url of the file to be downloaded + output_path : str, (optional) + full path, including filename, of where the file should be stored. If ``None``, + this function returns the downloaded file as string. + md5_checksum : str, optional (default=None) + If not None, should be a string of hexidecimal digits of the expected digest value. + exists_ok : bool, optional (default=True) + If False, raise an FileExistsError if there already exists a file at `output_path`. + encoding : str, optional (default='utf8') + The encoding with which the file should be stored. + """ + if output_path is not None: + try: + with open(output_path, encoding=encoding): + if exists_ok: + return None + else: + raise FileExistsError + except FileNotFoundError: + pass + + logging.info('Starting [%s] request for the URL %s', 'get', source) + start = time.time() + response = __read_url(source, request_method='get') + __check_response(response, source, None) + downloaded_file = response.text + + if md5_checksum is not None: + md5 = hashlib.md5() + md5.update(downloaded_file.encode('utf-8')) + md5_checksum_download = md5.hexdigest() + if md5_checksum != md5_checksum_download: + raise OpenMLHashException( + 'Checksum {} of downloaded file is unequal to the expected checksum {}.' + .format(md5_checksum_download, md5_checksum)) + + if output_path is None: + logging.info( + '%.7fs taken for [%s] request for the URL %s', + time.time() - start, + 'get', + source, + ) + return downloaded_file + + else: + with open(output_path, "w", encoding=encoding) as fh: + fh.write(downloaded_file) + + logging.info( + '%.7fs taken for [%s] request for the URL %s', + time.time() - start, + 'get', + source, + ) + + del downloaded_file + return None + + +def __check_response(response, url, file_elements): + if response.status_code != 200: + raise __parse_server_exception(response, url, file_elements=file_elements) + elif 'Content-Encoding' not in response.headers or \ + response.headers['Content-Encoding'] != 'gzip': + logging.warning('Received uncompressed content from OpenML for {}.'.format(url)) def _file_id_to_url(file_id, filename=None): @@ -75,7 +160,7 @@ def _file_id_to_url(file_id, filename=None): return url -def _read_url_files(url, data=None, file_elements=None): +def __read_url_files(url, data=None, file_elements=None): """do a post request to url with data and sending file_elements as files""" @@ -85,37 +170,24 @@ def _read_url_files(url, data=None, file_elements=None): file_elements = {} # Using requests.post sets header 'Accept-encoding' automatically to # 'gzip,deflate' - response = send_request( + response = __send_request( request_method='post', url=url, data=data, files=file_elements, ) - if response.status_code != 200: - raise _parse_server_exception(response, url, file_elements=file_elements) - if 'Content-Encoding' not in response.headers or \ - response.headers['Content-Encoding'] != 'gzip': - warnings.warn('Received uncompressed content from OpenML for {}.' - .format(url)) - return response.text + return response -def _read_url(url, request_method, data=None): +def __read_url(url, request_method, data=None): data = {} if data is None else data if config.apikey is not None: data['api_key'] = config.apikey - response = send_request(request_method=request_method, url=url, data=data) - if response.status_code != 200: - raise _parse_server_exception(response, url, file_elements=None) - if 'Content-Encoding' not in response.headers or \ - response.headers['Content-Encoding'] != 'gzip': - warnings.warn('Received uncompressed content from OpenML for {}.' - .format(url)) - return response.text + return __send_request(request_method=request_method, url=url, data=data) -def send_request( +def __send_request( request_method, url, data, @@ -149,16 +221,19 @@ def send_request( return response -def _parse_server_exception( +def __parse_server_exception( response: requests.Response, url: str, file_elements: Dict, ) -> OpenMLServerError: - # OpenML has a sophisticated error system - # where information about failures is provided. try to parse this + + if response.status_code == 414: + raise OpenMLServerError('URI too long! ({})'.format(url)) try: server_exception = xmltodict.parse(response.text) except Exception: + # OpenML has a sophisticated error system + # where information about failures is provided. try to parse this raise OpenMLServerError( 'Unexpected server error when calling {}. Please contact the developers!\n' 'Status code: {}\n{}'.format(url, response.status_code, response.text)) diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py index e85c55aa3..657fbc7c6 100644 --- a/openml/datasets/functions.py +++ b/openml/datasets/functions.py @@ -886,7 +886,7 @@ def _get_dataset_arff(description: Union[Dict, OpenMLDataset], output_file_path = os.path.join(cache_directory, "dataset.arff") try: - openml.utils._download_text_file( + openml._api_calls._download_text_file( source=url, output_path=output_file_path, md5_checksum=md5_checksum_fixture @@ -1038,13 +1038,11 @@ def _get_online_dataset_arff(dataset_id): str A string representation of an ARFF file. """ - dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id, - 'get') + dataset_xml = openml._api_calls._perform_api_call("data/%d" % dataset_id, 'get') # build a dict from the xml. # use the url from the dataset description and return the ARFF string - return openml._api_calls._read_url( + return openml._api_calls._download_text_file( xmltodict.parse(dataset_xml)['oml:data_set_description']['oml:url'], - request_method='get' ) diff --git a/openml/runs/run.py b/openml/runs/run.py index 140347cc4..910801971 100644 --- a/openml/runs/run.py +++ b/openml/runs/run.py @@ -327,8 +327,7 @@ def get_metric_fn(self, sklearn_fn, kwargs=None): predictions_file_url = openml._api_calls._file_id_to_url( self.output_files['predictions'], 'predictions.arff', ) - response = openml._api_calls._read_url(predictions_file_url, - request_method='get') + response = openml._api_calls._download_text_file(predictions_file_url) predictions_arff = arff.loads(response) # TODO: make this a stream reader else: diff --git a/openml/tasks/task.py b/openml/tasks/task.py index 0b79c2eca..72c12bab5 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -116,12 +116,10 @@ def _download_split(self, cache_file: str): pass except (OSError, IOError): split_url = self.estimation_procedure["data_splits_url"] - split_arff = openml._api_calls._read_url(split_url, - request_method='get') - - with io.open(cache_file, "w", encoding='utf8') as fh: - fh.write(split_arff) - del split_arff + openml._api_calls._download_text_file( + source=str(split_url), + output_path=cache_file, + ) def download_split(self) -> OpenMLSplit: """Download the OpenML split for a given task. diff --git a/openml/utils.py b/openml/utils.py index 09a0f6a83..2815f1afd 100644 --- a/openml/utils.py +++ b/openml/utils.py @@ -1,7 +1,6 @@ # License: BSD 3-Clause import os -import hashlib import xmltodict import shutil from typing import TYPE_CHECKING, List, Tuple, Union, Type @@ -366,53 +365,3 @@ def _create_lockfiles_dir(): except OSError: pass return dir - - -def _download_text_file(source: str, - output_path: str, - md5_checksum: str = None, - exists_ok: bool = True, - encoding: str = 'utf8', - ) -> None: - """ Download the text file at `source` and store it in `output_path`. - - By default, do nothing if a file already exists in `output_path`. - The downloaded file can be checked against an expected md5 checksum. - - Parameters - ---------- - source : str - url of the file to be downloaded - output_path : str - full path, including filename, of where the file should be stored. - md5_checksum : str, optional (default=None) - If not None, should be a string of hexidecimal digits of the expected digest value. - exists_ok : bool, optional (default=True) - If False, raise an FileExistsError if there already exists a file at `output_path`. - encoding : str, optional (default='utf8') - The encoding with which the file should be stored. - """ - try: - with open(output_path, encoding=encoding): - if exists_ok: - return - else: - raise FileExistsError - except FileNotFoundError: - pass - - downloaded_file = openml._api_calls._read_url(source, request_method='get') - - if md5_checksum is not None: - md5 = hashlib.md5() - md5.update(downloaded_file.encode('utf-8')) - md5_checksum_download = md5.hexdigest() - if md5_checksum != md5_checksum_download: - raise openml.exceptions.OpenMLHashException( - 'Checksum {} of downloaded file is unequal to the expected checksum {}.' - .format(md5_checksum_download, md5_checksum)) - - with open(output_path, "w", encoding=encoding) as fh: - fh.write(downloaded_file) - - del downloaded_file diff --git a/tests/test_openml/test_api_calls.py b/tests/test_openml/test_api_calls.py new file mode 100644 index 000000000..1748608bb --- /dev/null +++ b/tests/test_openml/test_api_calls.py @@ -0,0 +1,12 @@ +import openml +import openml.testing + + +class TestConfig(openml.testing.TestBase): + + def test_too_long_uri(self): + with self.assertRaisesRegex( + openml.exceptions.OpenMLServerError, + 'URI too long!', + ): + openml.datasets.list_datasets(data_id=list(range(10000))) diff --git a/tests/test_runs/test_run_functions.py b/tests/test_runs/test_run_functions.py index 2773bc8d9..fe8aab808 100644 --- a/tests/test_runs/test_run_functions.py +++ b/tests/test_runs/test_run_functions.py @@ -119,8 +119,7 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed): # downloads the predictions of the old task file_id = run.output_files['predictions'] predictions_url = openml._api_calls._file_id_to_url(file_id) - response = openml._api_calls._read_url(predictions_url, - request_method='get') + response = openml._api_calls._download_text_file(predictions_url) predictions = arff.loads(response) run_prime = openml.runs.run_model_on_task( model=model_prime, diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index de2d18981..152dd4dba 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -16,7 +16,7 @@ class OpenMLTaskTest(TestBase): def mocked_perform_api_call(call, request_method): # TODO: JvR: Why is this not a staticmethod? url = openml.config.server + '/' + call - return openml._api_calls._read_url(url, request_method=request_method) + return openml._api_calls._download_text_file(url) def test_list_all(self): openml.utils._list_all(listing_call=openml.tasks.functions._list_tasks)