diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml
index 85919f0afe..8c6bd6b337 100644
--- a/.github/workflows/codebuild-ci.yml
+++ b/.github/workflows/codebuild-ci.yml
@@ -55,7 +55,7 @@ jobs:
       - name: Run Codestyle & Doc Tests
         uses: aws-actions/aws-codebuild-run-build@v1
         with:
-          project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
+          project-name: ${{ github.event.repository.name }}-ci-codestyle-doc-tests
           source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
   unit-tests:
     runs-on: ubuntu-latest
@@ -74,7 +74,7 @@ jobs:
       - name: Run Unit Tests
        uses: aws-actions/aws-codebuild-run-build@v1
        with:
-          project-name: sagemaker-python-sdk-ci-unit-tests
+          project-name: ${{ github.event.repository.name }}-ci-unit-tests
           source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
           env-vars-for-codebuild: |
             PY_VERSION
@@ -93,5 +93,5 @@ jobs:
       - name: Run Integ Tests
         uses: aws-actions/aws-codebuild-run-build@v1
         with:
-          project-name: sagemaker-python-sdk-ci-integ-tests
+          project-name: ${{ github.event.repository.name }}-ci-integ-tests
           source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d85261dd7..63e5114f10 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,26 @@
 # Changelog
 
+## v2.221.1 (2024-05-22)
+
+### Bug Fixes and Other Changes
+
+ * Convert pytorchddp distribution to smdistributed distribution
+ * Add tei cpu image
+
+## v2.221.0 (2024-05-20)
+
+### Features
+
+ * onboard tei image config to pysdk
+
+### Bug Fixes and Other Changes
+
+ * JS Model with non-TGI/non-DJL deployment failure
+ * cover tei with image_uris.retrieve API
+ * Add more debugging
+ * model builder limited container support for endpoint mode.
+ * Image URI should take precedence for HF models
+
 ## v2.220.0 (2024-05-15)
 
 ### Features
diff --git a/VERSION b/VERSION
index bd29d7bb5b..e55266069e 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-2.220.1.dev0
+2.221.2.dev0
diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index ad6322365f..889ff72779 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -12,7 +12,7 @@ awslogs==0.14.0
 black==24.3.0
 stopit==1.1.2
 # Update tox.ini to have correct version of airflow constraints file
-apache-airflow==2.9.0
+apache-airflow==2.9.1
 apache-airflow-providers-amazon==7.2.1
 attrs>=23.1.0,<24
 fabric==2.6.0
diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py
index 5b4d0d6790..f02b275cbe 100644
--- a/src/sagemaker/enums.py
+++ b/src/sagemaker/enums.py
@@ -28,3 +28,15 @@ class EndpointType(Enum):
     INFERENCE_COMPONENT_BASED = (
         "InferenceComponentBased"  # Amazon SageMaker Inference Component Based Endpoint
     )
+
+
+class RoutingStrategy(Enum):
+    """Strategy for routing HTTPS traffic."""
+
+    RANDOM = "RANDOM"
+    """The endpoint routes each request to a randomly chosen instance.
+    """
+    LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
+    """The endpoint routes requests to the specific instances that have
+    more capacity to process them.
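+    SageMaker favors the instances with the fewest in-flight requests
+    (this maps to the ``RoutingConfig`` field of the ``CreateEndpointConfig``
+    API).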
+ """ diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 33018becdd..be3658365a 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -145,22 +145,6 @@ ], } -PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ - "1.10", - "1.10.0", - "1.10.2", - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - "1.13.1", - "2.0.0", - "2.0.1", - "2.1.0", - "2.2.0", -] - TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ "1.13.1", "2.0.0", @@ -795,7 +779,6 @@ def _validate_smdataparallel_args( Raises: ValueError: if - (`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or `py_version` is not python3 or `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION """ @@ -806,17 +789,10 @@ def _validate_smdataparallel_args( if not smdataparallel_enabled: return - is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES - err_msg = "" - if not is_instance_type_supported: - # instance_type is required - err_msg += ( - f"Provided instance_type {instance_type} is not supported by smdataparallel.\n" - "Please specify one of the supported instance types:" - f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n" - ) + if not instance_type: + err_msg += "Please specify an instance_type for smdataparallel.\n" if not image_uri: # ignore framework_version & py_version if image_uri is set @@ -928,13 +904,6 @@ def validate_distribution( ) if framework_name and framework_name == "pytorch": # We need to validate only for PyTorch framework - validate_pytorch_distribution( - distribution=validated_distribution, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, - image_uri=image_uri, - ) validate_torch_distributed_distribution( instance_type=instance_type, distribution=validated_distribution, @@ -968,13 +937,6 @@ def validate_distribution( ) if framework_name and framework_name == "pytorch": # We need to validate only for PyTorch framework - validate_pytorch_distribution( - distribution=validated_distribution, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, - image_uri=image_uri, - ) validate_torch_distributed_distribution( instance_type=instance_type, distribution=validated_distribution, @@ -1023,63 +985,6 @@ def validate_distribution_for_instance_type(instance_type, distribution): raise ValueError(err_msg) -def validate_pytorch_distribution( - distribution, framework_name, framework_version, py_version, image_uri -): - """Check if pytorch distribution strategy is correctly invoked by the user. - - Args: - distribution (dict): A dictionary with information to enable distributed training. - (Defaults to None if distributed training is not enabled.) For example: - - .. code:: python - - { - "pytorchddp": { - "enabled": True - } - } - framework_name (str): A string representing the name of framework selected. - framework_version (str): A string representing the framework version selected. - py_version (str): A string representing the python version selected. - image_uri (str): A string representing a Docker image URI. 
- - Raises: - ValueError: if - `py_version` is not python3 or - `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS - """ - if framework_name and framework_name != "pytorch": - # We need to validate only for PyTorch framework - return - - pytorch_ddp_enabled = False - if "pytorchddp" in distribution: - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) - if not pytorch_ddp_enabled: - # Distribution strategy other than pytorchddp is selected - return - - err_msg = "" - if not image_uri: - # ignore framework_version and py_version if image_uri is set - # in case image_uri is not set, then both are mandatory - if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: - err_msg += ( - f"Provided framework_version {framework_version} is not supported by" - " pytorchddp.\n" - "Please specify one of the supported framework versions:" - f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" - ) - if "py3" not in py_version: - err_msg += ( - f"Provided py_version {py_version} is not supported by pytorchddp.\n" - "Please specify py_version>=py3" - ) - if err_msg: - raise ValueError(err_msg) - - def validate_torch_distributed_distribution( instance_type, distribution, diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index de5e624dbc..9927d1d293 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -65,6 +65,20 @@ def get_huggingface_llm_image_uri( image_scope="inference", inference_tool="neuronx", ) + if backend == "huggingface-tei": + return image_uris.retrieve( + "huggingface-tei", + region=region, + version=version, + image_scope="inference", + ) + if backend == "huggingface-tei-cpu": + return image_uris.retrieve( + "huggingface-tei-cpu", + region=region, + version=version, + image_scope="inference", + ) if backend == "lmi": version = version or "0.24.0" return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index f71dca0ac8..662baecae6 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -334,6 +334,7 @@ def deploy( endpoint_type=kwargs.get("endpoint_type", None), resources=kwargs.get("resources", None), managed_instance_scaling=kwargs.get("managed_instance_scaling", None), + routing_config=kwargs.get("routing_config", None), ) def register( diff --git a/src/sagemaker/image_uri_config/djl-lmi.json b/src/sagemaker/image_uri_config/djl-lmi.json new file mode 100644 index 0000000000..9d2cdd699a --- /dev/null +++ b/src/sagemaker/image_uri_config/djl-lmi.json @@ -0,0 +1,39 @@ +{ + "scope": [ + "inference" + ], + "versions": { + "0.28.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + 
"us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.28.0-lmi10.0.0-cu124" + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/djl-neuronx.json b/src/sagemaker/image_uri_config/djl-neuronx.json index a63acc87e4..6038946e28 100644 --- a/src/sagemaker/image_uri_config/djl-neuronx.json +++ b/src/sagemaker/image_uri_config/djl-neuronx.json @@ -3,6 +3,24 @@ "inference" ], "versions": { + "0.28.0": { + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "eu-central-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.28.0-neuronx-sdk2.18.2" + }, "0.27.0": { "registries": { "ap-northeast-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/djl-tensorrtllm.json b/src/sagemaker/image_uri_config/djl-tensorrtllm.json index e125cbd419..6cde6109bb 100644 --- a/src/sagemaker/image_uri_config/djl-tensorrtllm.json +++ b/src/sagemaker/image_uri_config/djl-tensorrtllm.json @@ -3,6 +3,38 @@ "inference" ], "versions": { + "0.28.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.28.0-tensorrtllm0.9.0-cu122" + }, "0.27.0": { "registries": { "af-south-1": "626614931356", diff --git a/src/sagemaker/image_uri_config/huggingface-tei-cpu.json b/src/sagemaker/image_uri_config/huggingface-tei-cpu.json new file mode 100644 index 0000000000..d68b0d6307 --- /dev/null +++ b/src/sagemaker/image_uri_config/huggingface-tei-cpu.json @@ -0,0 +1,59 @@ +{ + "inference": { + "processors": [ + "cpu" + ], + "version_aliases": { + "1.2": "1.2.3" + }, + "versions": { + "1.2.3": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", 
+ "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.2.3", + "repository": "tei-cpu", + "container_version": { + "cpu": "ubuntu22.04" + } + } + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-tei.json b/src/sagemaker/image_uri_config/huggingface-tei.json new file mode 100644 index 0000000000..b7c597df18 --- /dev/null +++ b/src/sagemaker/image_uri_config/huggingface-tei.json @@ -0,0 +1,59 @@ +{ + "inference": { + "processors": [ + "gpu" + ], + "version_aliases": { + "1.2": "1.2.3" + }, + "versions": { + "1.2.3": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.2.3", + "repository": "tei", + "container_version": { + "gpu": "cu122-ubuntu22.04" + } + } + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index f04f9c6b69..b846d51246 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -1315,7 +1315,8 @@ "1.13": "1.13.1", "2.0": "2.0.1", "2.1": "2.1.0", - "2.2": "2.2.0" + "2.2": "2.2.0", + "2.3": "2.3.0" }, "versions": { "0.4.0": { @@ -2288,6 +2289,47 @@ "us-west-2": "763104351884" }, "repository": "pytorch-training" + }, + "2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": 
"763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" } } } diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 97471f2c41..743f6b1f99 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -37,6 +37,8 @@ ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}" HUGGING_FACE_FRAMEWORK = "huggingface" HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm" +HUGGING_FACE_TEI_GPU_FRAMEWORK = "huggingface-tei" +HUGGING_FACE_TEI_CPU_FRAMEWORK = "huggingface-tei-cpu" HUGGING_FACE_LLM_NEURONX_FRAMEWORK = "huggingface-llm-neuronx" XGBOOST_FRAMEWORK = "xgboost" SKLEARN_FRAMEWORK = "sklearn" @@ -480,6 +482,8 @@ def _validate_version_and_set_if_needed(version, config, framework): if version is None and framework in [ DATA_WRANGLER_FRAMEWORK, HUGGING_FACE_LLM_FRAMEWORK, + HUGGING_FACE_TEI_GPU_FRAMEWORK, + HUGGING_FACE_TEI_CPU_FRAMEWORK, HUGGING_FACE_LLM_NEURONX_FRAMEWORK, STABILITYAI_FRAMEWORK, ]: diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 79a7b18788..0bae5955e2 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -678,6 +678,7 @@ def get_deploy_kwargs( endpoint_type: Optional[EndpointType] = None, training_config_name: Optional[str] = None, config_name: Optional[str] = None, + routing_config: Optional[Dict[str, Any]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -710,6 +711,7 @@ def get_deploy_kwargs( endpoint_logging=endpoint_logging, resources=resources, config_name=config_name, + routing_config=routing_config, ) deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) @@ -800,7 +802,6 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, - config_name=config_name, ) model_specs = verify_model_region_and_return_specs( diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 81efc1f17a..f72a3140dc 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -583,6 +583,7 @@ def deploy( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, + routing_config: Optional[Dict[str, Any]] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -677,6 +678,8 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + routing_config (Optional[Dict]): Settings the control how the endpoint routes + incoming traffic to the instances that the endpoint hosts. Raises: MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. 
@@ -713,6 +716,7 @@
             endpoint_type=endpoint_type,
             model_type=self.model_type,
             config_name=self.config_name,
+            routing_config=routing_config,
         )
         if (
             self.model_type == JumpStartModelType.PROPRIETARY
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py
index 04e8b91e26..f197421d65 100644
--- a/src/sagemaker/jumpstart/types.py
+++ b/src/sagemaker/jumpstart/types.py
@@ -1647,6 +1647,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
         "resources",
         "endpoint_type",
         "config_name",
+        "routing_config",
     ]
 
     SERIALIZATION_EXCLUSION_SET = {
@@ -1693,6 +1694,7 @@ def __init__(
         resources: Optional[ResourceRequirements] = None,
         endpoint_type: Optional[EndpointType] = None,
         config_name: Optional[str] = None,
+        routing_config: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Instantiates JumpStartModelDeployKwargs object."""
 
@@ -1726,6 +1728,7 @@ def __init__(
         self.resources = resources
         self.endpoint_type = endpoint_type
         self.config_name = config_name
+        self.routing_config = routing_config
 
 
 class JumpStartEstimatorInitKwargs(JumpStartKwargs):
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index fd21b6342e..1bb6cb2e5c 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -20,7 +20,7 @@
 import os
 import re
 import copy
-from typing import List, Dict, Optional, Union
+from typing import List, Dict, Optional, Union, Any
 
 import sagemaker
 from sagemaker import (
@@ -66,6 +66,7 @@
     resolve_nested_dict_value_from_config,
     format_tags,
     Tags,
+    _resolve_routing_config,
 )
 from sagemaker.async_inference import AsyncInferenceConfig
 from sagemaker.predictor_async import AsyncPredictor
@@ -1309,6 +1310,7 @@ def deploy(
         resources: Optional[ResourceRequirements] = None,
         endpoint_type: EndpointType = EndpointType.MODEL_BASED,
         managed_instance_scaling: Optional[str] = None,
+        routing_config: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1406,6 +1408,15 @@ def deploy(
                 Endpoint. (Default: None).
             endpoint_type (Optional[EndpointType]): The type of an endpoint used to deploy
                 models. (Default: EndpointType.MODEL_BASED).
+            routing_config (Optional[Dict[str, Any]]): Settings that control how the endpoint
+                routes incoming traffic to the instances that the endpoint hosts.
+                Currently, the supported dictionary key is ``RoutingStrategy``.
+
+                .. code:: python
+
+                    {
+                        "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
+                    }
 
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -1458,6 +1469,8 @@
         if self.role is None:
             raise ValueError("Role can not be null for deploying a model")
 
+        routing_config = _resolve_routing_config(routing_config)
+
         if (
             inference_recommendation_id is not None
             or self.inference_recommender_job_results is not None
@@ -1543,6 +1556,7 @@
                 model_data_download_timeout=model_data_download_timeout,
                 container_startup_health_check_timeout=container_startup_health_check_timeout,
                 managed_instance_scaling=managed_instance_scaling_config,
+                routing_config=routing_config,
             )
 
             self.sagemaker_session.endpoint_from_production_variants(
@@ -1625,6 +1639,7 @@
             volume_size=volume_size,
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
+            routing_config=routing_config,
         )
         if endpoint_name:
             self.endpoint_name = endpoint_name
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index a4e24d1ff0..412926279c 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -276,6 +276,20 @@ def __init__(
             kwargs["entry_point"] = entry_point
 
         if distribution is not None:
+            # rewrite pytorchddp to smdistributed
+            if "pytorchddp" in distribution:
+                if "smdistributed" in distribution:
+                    raise ValueError(
+                        "Cannot use both pytorchddp and smdistributed "
+                        "distribution options together.",
+                        distribution,
+                    )
+
+                # convert pytorchddp distribution into smdistributed distribution
+                distribution = distribution.copy()
+                distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]}
+                del distribution["pytorchddp"]
+
             distribution = validate_distribution(
                 distribution,
                 self.instance_groups,
diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py
index abe180644e..e8ef546f7a 100644
--- a/src/sagemaker/serve/builder/jumpstart_builder.py
+++ b/src/sagemaker/serve/builder/jumpstart_builder.py
@@ -23,6 +23,7 @@
 from sagemaker import model_uris
 from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
 from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
+from sagemaker.serve.model_server.multi_model_server.prepare import prepare_mms_js_resources
 from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.utils.exceptions import (
@@ -35,6 +36,7 @@
 from sagemaker.serve.utils.predictors import (
     DjlLocalModePredictor,
     TgiLocalModePredictor,
+    TransformersLocalModePredictor,
 )
 from sagemaker.serve.utils.local_hardware import (
     _get_nb_instance,
@@ -90,6 +92,7 @@ def __init__(self):
         self.existing_properties = None
         self.prepared_for_tgi = None
         self.prepared_for_djl = None
+        self.prepared_for_mms = None
         self.schema_builder = None
         self.nb_instance_type = None
         self.ram_usage_model_load = None
@@ -137,7 +140,11 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
 
             if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                 self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
-                if not hasattr(self, "prepared_for_djl") or not hasattr(self, "prepared_for_tgi"):
+                if (
+                    not hasattr(self, "prepared_for_djl")
+                    or not hasattr(self, "prepared_for_tgi")
+                    or not hasattr(self, "prepared_for_mms")
+                ):
                     self.pysdk_model.model_data, env = self._prepare_for_mode()
             elif overwrite_mode == Mode.LOCAL_CONTAINER:
                 self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
@@ -160,6 +167,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
                     dependencies=self.dependencies,
                     model_data=self.pysdk_model.model_data,
                 )
+            elif not hasattr(self, "prepared_for_mms"):
+                self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
+                    model_path=self.model_path,
+                    js_id=self.model,
+                    dependencies=self.dependencies,
+                    model_data=self.pysdk_model.model_data,
+                )
             self._prepare_for_mode()
 
         env = {}
@@ -179,6 +193,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
                 predictor = TgiLocalModePredictor(
                     self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
                 )
+            elif self.model_server == ModelServer.MMS:
+                predictor = TransformersLocalModePredictor(
+                    self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
+                )
 
             ram_usage_before = _get_ram_usage_mb()
             self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
@@ -254,6 +272,24 @@ def _build_for_tgi_jumpstart(self):
 
         self.pysdk_model.env.update(env)
 
+    def _build_for_mms_jumpstart(self):
+        """Placeholder docstring"""
+
+        env = {}
+        if self.mode == Mode.LOCAL_CONTAINER:
+            if not hasattr(self, "prepared_for_mms"):
+                self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
+                    model_path=self.model_path,
+                    js_id=self.model,
+                    dependencies=self.dependencies,
+                    model_data=self.pysdk_model.model_data,
+                )
+            self._prepare_for_mode()
+        elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_mms"):
+            self.pysdk_model.model_data, env = self._prepare_for_mode()
+
+        self.pysdk_model.env.update(env)
+
     def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
         """Tune for Jumpstart Models in Local Mode.
 
@@ -264,7 +300,7 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
         returns:
             Tuned Model.
         """
-        if self.mode != Mode.LOCAL_CONTAINER:
+        if self.mode == Mode.SAGEMAKER_ENDPOINT:
             logger.warning(
                 "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
             )
@@ -485,47 +521,58 @@ def _build_for_jumpstart(self):
         self.secret_key = None
         self.jumpstart = True
 
-        self.pysdk_model = self._create_pre_trained_js_model()
+        pysdk_model = self._create_pre_trained_js_model()
+        image_uri = pysdk_model.image_uri
 
-        logger.info(
-            "JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri
-        )
+        logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
 
-        if self._is_gated_model() and self.mode != Mode.SAGEMAKER_ENDPOINT:
+        if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
             raise ValueError(
                 "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
             )
 
-        if "djl-inference" in self.pysdk_model.image_uri:
+        if "djl-inference" in image_uri:
             logger.info("Building for DJL JumpStart Model ID...")
             self.model_server = ModelServer.DJL_SERVING
 
+            self.pysdk_model = pysdk_model
             self.image_uri = self.pysdk_model.image_uri
 
             self._build_for_djl_jumpstart()
 
             self.pysdk_model.tune = self.tune_for_djl_jumpstart
-        elif "tgi-inference" in self.pysdk_model.image_uri:
+        elif "tgi-inference" in image_uri:
             logger.info("Building for TGI JumpStart Model ID...")
             self.model_server = ModelServer.TGI
 
+            self.pysdk_model = pysdk_model
             self.image_uri = self.pysdk_model.image_uri
 
             self._build_for_tgi_jumpstart()
 
             self.pysdk_model.tune = self.tune_for_tgi_jumpstart
-        else:
+        elif "huggingface-pytorch-inference:" in image_uri:
+            logger.info("Building for MMS JumpStart Model ID...")
+            self.model_server = ModelServer.MMS
+
+            self.pysdk_model = pysdk_model
+            self.image_uri = self.pysdk_model.image_uri
+
+            self._build_for_mms_jumpstart()
+        elif self.mode != Mode.SAGEMAKER_ENDPOINT:
             raise ValueError(
-                "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
+                "JumpStart Model ID was not packaged "
+                "with djl-inference, tgi-inference, or mms-inference container."
             )
 
         return self.pysdk_model
 
-    def _is_gated_model(self) -> bool:
+    def _is_gated_model(self, model) -> bool:
         """Determine if ``this`` Model is Gated
 
+        Args:
+            model (Model): Jumpstart Model
         Returns:
             bool: ``True`` if ``this`` Model is Gated
         """
-        s3_uri = self.pysdk_model.model_data
+        s3_uri = model.model_data
         if isinstance(s3_uri, dict):
             s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 42a2b994a8..44bc46b00b 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -36,6 +36,7 @@
 from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
 from sagemaker.serve.builder.serve_settings import _ServeSettings
 from sagemaker.serve.builder.djl_builder import DJL
+from sagemaker.serve.builder.tei_builder import TEI
 from sagemaker.serve.builder.tgi_builder import TGI
 from sagemaker.serve.builder.jumpstart_builder import JumpStart
 from sagemaker.serve.builder.transformers_builder import Transformers
@@ -95,9 +96,9 @@
 }
 
 
-# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
+# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
 @dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, TEI):
     """Class that builds a deployable model.
 
     Args:
@@ -168,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
             in order for model builder to build the artifacts correctly (according
             to the model server). Possible values for this argument are
             ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
-            ``TRITON``, and``TGI``.
+            ``TRITON``, ``TGI``, and ``TEI``.
         model_metadata (Optional[Dict[str, Any]): Dictionary used to override
             model metadata. Currently, ``HF_TASK`` is overridable for HuggingFace model.
             HF_TASK should be set for new models without task metadata in the Hub, adding unsupported task types will throw
@@ -753,7 +754,7 @@ def build(  # pylint: disable=R0911
             model_task = self.model_metadata.get("HF_TASK")
         if self._is_jumpstart_model_id():
             return self._build_for_jumpstart()
-        if self._is_djl():  # pylint: disable=R1705
+        if self._is_djl():
             return self._build_for_djl()
         else:
             hf_model_md = get_huggingface_model_metadata(
@@ -764,8 +765,10 @@ def build(  # pylint: disable=R0911
                 model_task = hf_model_md.get("pipeline_tag")
             if self.schema_builder is None and model_task is not None:
                 self._hf_schema_builder_init(model_task)
-            if model_task == "text-generation":  # pylint: disable=R1705
+            if model_task == "text-generation":
                 return self._build_for_tgi()
+            if model_task == "sentence-similarity":
+                return self._build_for_tei()
             elif self._can_fit_on_single_gpu():
                 return self._build_for_transformers()
             elif (
diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py
new file mode 100644
index 0000000000..6aba3c9da2
--- /dev/null
+++ b/src/sagemaker/serve/builder/tei_builder.py
@@ -0,0 +1,224 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds mixin logic to support deployment of Model ID"""
+from __future__ import absolute_import
+import logging
+from typing import Type
+from abc import ABC, abstractmethod
+
+from sagemaker import image_uris
+from sagemaker.model import Model
+from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
+
+from sagemaker.huggingface import HuggingFaceModel
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+)
+from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
+from sagemaker.serve.utils.predictors import TeiLocalModePredictor
+from sagemaker.serve.utils.types import ModelServer
+from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
+from sagemaker.base_predictor import PredictorBase
+
+logger = logging.getLogger(__name__)
+
+_CODE_FOLDER = "code"
+
+
+class TEI(ABC):
+    """TEI build logic for ModelBuilder()"""
+
+    def __init__(self):
+        self.model = None
+        self.serve_settings = None
+        self.sagemaker_session = None
+        self.model_path = None
+        self.dependencies = None
+        self.modes = None
+        self.mode = None
+        self.model_server = None
+        self.image_uri = None
+        self._is_custom_image_uri = False
+        self.image_config = None
+        self.vpc_config = None
+        self._original_deploy = None
+        self.hf_model_config = None
+        self._default_tensor_parallel_degree = None
+        self._default_data_type = None
+        self._default_max_tokens = None
+        self.pysdk_model = None
+        self.schema_builder = None
+        self.env_vars = None
+        self.nb_instance_type = None
+        self.ram_usage_model_load = None
+        self.secret_key = None
+        self.jumpstart = None
+        self.role_arn = None
+
+    @abstractmethod
+    def _prepare_for_mode(self):
+        """Placeholder docstring"""
+
+    @abstractmethod
+    def _get_client_translators(self):
+        """Placeholder docstring"""
+
+    def _set_to_tei(self):
+        """Placeholder docstring"""
+        if self.model_server != ModelServer.TEI:
+            messaging = (
+                "HuggingFace Model ID support on model server: "
+                f"{self.model_server} is not currently supported. "
+                f"Defaulting to {ModelServer.TEI}"
+            )
+            logger.warning(messaging)
+            self.model_server = ModelServer.TEI
+
+    def _create_tei_model(self, **kwargs) -> Type[Model]:
+        """Placeholder docstring"""
+        if self.nb_instance_type and "instance_type" not in kwargs:
+            kwargs.update({"instance_type": self.nb_instance_type})
+
+        if not self.image_uri:
+            self.image_uri = image_uris.retrieve(
+                "huggingface-tei",
+                image_scope="inference",
+                instance_type=kwargs.get("instance_type"),
+                region=self.sagemaker_session.boto_region_name,
+            )
+
+        pysdk_model = HuggingFaceModel(
+            image_uri=self.image_uri,
+            image_config=self.image_config,
+            vpc_config=self.vpc_config,
+            env=self.env_vars,
+            role=self.role_arn,
+            sagemaker_session=self.sagemaker_session,
+        )
+
+        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)
+
+        self._original_deploy = pysdk_model.deploy
+        pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
+        return pysdk_model
+
+    @_capture_telemetry("tei.deploy")
+    def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
+        """Placeholder docstring"""
+        timeout = kwargs.get("model_data_download_timeout")
+        if timeout:
+            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})
+
+        if "mode" in kwargs and kwargs.get("mode") != self.mode:
+            overwrite_mode = kwargs.get("mode")
+            # mode overwritten by customer during model.deploy()
+            logger.warning(
+                "Deploying in %s Mode, overriding existing configurations set for %s mode",
+                overwrite_mode,
+                self.mode,
+            )
+
+            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
+                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
+            elif overwrite_mode == Mode.LOCAL_CONTAINER:
+                self._prepare_for_mode()
+                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
+            else:
+                raise ValueError("Mode %s is not supported!" % overwrite_mode)
+
+        serializer = self.schema_builder.input_serializer
+        deserializer = self.schema_builder._output_deserializer
+        if self.mode == Mode.LOCAL_CONTAINER:
+            timeout = kwargs.get("model_data_download_timeout")
+
+            predictor = TeiLocalModePredictor(
+                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
+                self.image_uri,
+                timeout if timeout else 1800,
+                None,
+                predictor,
+                self.pysdk_model.env,
+                jumpstart=False,
+            )
+
+            return predictor
+
+        if "mode" in kwargs:
+            del kwargs["mode"]
+        if "role" in kwargs:
+            self.pysdk_model.role = kwargs.get("role")
+            del kwargs["role"]
+
+        # set model_data to uncompressed s3 dict
+        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
+        self.env_vars.update(env_vars)
+        self.pysdk_model.env.update(self.env_vars)
+
+        # if the weights have been cached via local container mode -> set to offline
+        if str(Mode.LOCAL_CONTAINER) in self.modes:
+            self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
+        else:
+            # if has not been built for local container we must use cache
+            # that hosting has write access to.
+ self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" + self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" + + if "endpoint_logging" not in kwargs: + kwargs["endpoint_logging"] = True + + if self.nb_instance_type and "instance_type" not in kwargs: + kwargs.update({"instance_type": self.nb_instance_type}) + elif not self.nb_instance_type and "instance_type" not in kwargs: + raise ValueError( + "Instance type must be provided when deploying " "to SageMaker Endpoint mode." + ) + + if "initial_instance_count" not in kwargs: + kwargs.update({"initial_instance_count": 1}) + + predictor = self._original_deploy(*args, **kwargs) + + predictor.serializer = serializer + predictor.deserializer = deserializer + return predictor + + def _build_for_hf_tei(self): + """Placeholder docstring""" + self.nb_instance_type = _get_nb_instance() + + _create_dir_structure(self.model_path) + if not hasattr(self, "pysdk_model"): + self.env_vars.update({"HF_MODEL_ID": self.model}) + self.hf_model_config = _get_model_config_properties_from_hf( + self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + ) + + self.pysdk_model = self._create_tei_model() + + if self.mode == Mode.LOCAL_CONTAINER: + self._prepare_for_mode() + + return self.pysdk_model + + def _build_for_tei(self): + """Placeholder docstring""" + self.secret_key = None + + self._set_to_tei() + + self.pysdk_model = self._build_for_hf_tei() + return self.pysdk_model diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index ead9b7425f..f84d8f868d 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -78,6 +78,25 @@ def _prepare_for_mode(self): """Abstract method""" def _create_transformers_model(self) -> Type[Model]: + """Initializes HF model with or without image_uri""" + if self.image_uri is None: + pysdk_model = self._get_hf_metadata_create_model() + else: + pysdk_model = HuggingFaceModel( + image_uri=self.image_uri, + vpc_config=self.vpc_config, + env=self.env_vars, + role=self.role_arn, + sagemaker_session=self.sagemaker_session, + ) + + logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri) + + self._original_deploy = pysdk_model.deploy + pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper + return pysdk_model + + def _get_hf_metadata_create_model(self) -> Type[Model]: """Initializes the model after fetching image 1. Get the metadata for deciding framework @@ -141,13 +160,12 @@ def _create_transformers_model(self) -> Type[Model]: self.sagemaker_session.boto_region_name, self.instance_type ) - logger.info("Detected %s. 
Proceeding with the the deployment.", self.image_uri) + if pysdk_model is None or self.image_uri is None: + raise ValueError("PySDK model unable to be created, try overriding image_uri") if not pysdk_model.image_uri: pysdk_model.image_uri = self.image_uri - self._original_deploy = pysdk_model.deploy - pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper return pysdk_model @_capture_telemetry("transformers.deploy") diff --git a/src/sagemaker/serve/mode/local_container_mode.py b/src/sagemaker/serve/mode/local_container_mode.py index f940e2959c..f040c61c1d 100644 --- a/src/sagemaker/serve/mode/local_container_mode.py +++ b/src/sagemaker/serve/mode/local_container_mode.py @@ -21,6 +21,7 @@ from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing from sagemaker.serve.model_server.triton.server import LocalTritonServer from sagemaker.serve.model_server.tgi.server import LocalTgiServing +from sagemaker.serve.model_server.tei.server import LocalTeiServing from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer from sagemaker.session import Session @@ -69,6 +70,7 @@ def __init__( self.container = None self.secret_key = None self._ping_container = None + self._invoke_serving = None def load(self, model_path: str = None): """Placeholder docstring""" @@ -156,6 +158,19 @@ def create_server( env_vars=env_vars if env_vars else self.env_vars, ) self._ping_container = self._tensorflow_serving_deep_ping + elif self.model_server == ModelServer.TEI: + tei_serving = LocalTeiServing() + tei_serving._start_tei_serving( + client=self.client, + image=image, + model_path=model_path if model_path else self.model_path, + secret_key=secret_key, + env_vars=env_vars if env_vars else self.env_vars, + ) + tei_serving.schema_builder = self.schema_builder + self.container = tei_serving.container + self._ping_container = tei_serving._tei_deep_ping + self._invoke_serving = tei_serving._invoke_tei_serving # allow some time for container to be ready time.sleep(10) diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index 24acfc6a2f..b8f1d0529b 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -6,6 +6,7 @@ import logging from typing import Type +from sagemaker.serve.model_server.tei.server import SageMakerTeiServing from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing from sagemaker.session import Session from sagemaker.serve.utils.types import ModelServer @@ -37,6 +38,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe self.inference_spec = inference_spec self.model_server = model_server + self._tei_serving = SageMakerTeiServing() + def load(self, model_path: str): """Placeholder docstring""" path = Path(model_path) @@ -66,8 +69,9 @@ def prepare( + "session to be created or supply `sagemaker_session` into @serve.invoke." 
) from e + upload_artifacts = None if self.model_server == ModelServer.TORCHSERVE: - return self._upload_torchserve_artifacts( + upload_artifacts = self._upload_torchserve_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, secret_key=secret_key, @@ -76,7 +80,7 @@ def prepare( ) if self.model_server == ModelServer.TRITON: - return self._upload_triton_artifacts( + upload_artifacts = self._upload_triton_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, secret_key=secret_key, @@ -85,7 +89,7 @@ def prepare( ) if self.model_server == ModelServer.DJL_SERVING: - return self._upload_djl_artifacts( + upload_artifacts = self._upload_djl_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, @@ -93,7 +97,7 @@ def prepare( ) if self.model_server == ModelServer.TGI: - return self._upload_tgi_artifacts( + upload_artifacts = self._upload_tgi_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, @@ -102,7 +106,7 @@ def prepare( ) if self.model_server == ModelServer.MMS: - return self._upload_server_artifacts( + upload_artifacts = self._upload_server_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, @@ -110,7 +114,7 @@ def prepare( ) if self.model_server == ModelServer.TENSORFLOW_SERVING: - return self._upload_tensorflow_serving_artifacts( + upload_artifacts = self._upload_tensorflow_serving_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, secret_key=secret_key, @@ -118,4 +122,15 @@ def prepare( image=image, ) + if self.model_server == ModelServer.TEI: + upload_artifacts = self._tei_serving._upload_tei_artifacts( + model_path=model_path, + sagemaker_session=sagemaker_session, + s3_model_data_url=s3_model_data_url, + image=image, + ) + + if upload_artifacts: + return upload_artifacts + raise ValueError("%s model server is not supported" % self.model_server) diff --git a/src/sagemaker/serve/model_server/multi_model_server/prepare.py b/src/sagemaker/serve/model_server/multi_model_server/prepare.py index 7a16cc0a43..7059d9026d 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/prepare.py +++ b/src/sagemaker/serve/model_server/multi_model_server/prepare.py @@ -15,7 +15,9 @@ from __future__ import absolute_import import logging from pathlib import Path +from typing import List +from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage logger = logging.getLogger(__name__) @@ -36,3 +38,28 @@ def _create_dir_structure(model_path: str) -> tuple: _check_docker_disk_usage() return model_path, code_dir + + +def prepare_mms_js_resources( + model_path: str, + js_id: str, + shared_libs: List[str] = None, + dependencies: str = None, + model_data: str = None, +) -> tuple: + """Prepare serving when a JumpStart model id is given + + Args: + model_path (str) : Argument + js_id (str): Argument + shared_libs (List[]) : Argument + dependencies (str) : Argument + model_data (str) : Argument + + Returns: + ( str ) : + + """ + model_path, code_dir = _create_dir_structure(model_path) + + return _copy_jumpstart_artifacts(model_data, js_id, code_dir) diff --git a/src/sagemaker/serve/model_server/tei/__init__.py b/src/sagemaker/serve/model_server/tei/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/serve/model_server/tei/server.py 
b/src/sagemaker/serve/model_server/tei/server.py new file mode 100644 index 0000000000..67fca0e847 --- /dev/null +++ b/src/sagemaker/serve/model_server/tei/server.py @@ -0,0 +1,160 @@ +"""Module for Local TEI Serving""" + +from __future__ import absolute_import + +import requests +import logging +from pathlib import Path +from docker.types import DeviceRequest +from sagemaker import Session, fw_utils +from sagemaker.serve.utils.exceptions import LocalModelInvocationException +from sagemaker.base_predictor import PredictorBase +from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join +from sagemaker.s3 import S3Uploader +from sagemaker.local.utils import get_docker_host + + +MODE_DIR_BINDING = "/opt/ml/model/" +_SHM_SIZE = "2G" +_DEFAULT_ENV_VARS = { + "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", +} + +logger = logging.getLogger(__name__) + + +class LocalTeiServing: + """LocalTeiServing class""" + + def _start_tei_serving( + self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict + ): + """Starts a local tei serving container. + + Args: + client: Docker client + image: Image to use + model_path: Path to the model + secret_key: Secret key to use for authentication + env_vars: Environment variables to set + """ + if env_vars and secret_key: + env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key + + self.container = client.containers.run( + image, + shm_size=_SHM_SIZE, + device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])], + network_mode="host", + detach=True, + auto_remove=True, + volumes={ + Path(model_path).joinpath("code"): { + "bind": MODE_DIR_BINDING, + "mode": "rw", + }, + }, + environment=_update_env_vars(env_vars), + ) + + def _invoke_tei_serving(self, request: object, content_type: str, accept: str): + """Invokes a local tei serving container. + + Args: + request: Request to send + content_type: Content type to use + accept: Accept to use + """ + try: + response = requests.post( + f"http://{get_docker_host()}:8080/invocations", + data=request, + headers={"Content-Type": content_type, "Accept": accept}, + timeout=600, + ) + response.raise_for_status() + return response.content + except Exception as e: + raise Exception("Unable to send request to the local container server") from e + + def _tei_deep_ping(self, predictor: PredictorBase): + """Checks if the local tei serving container is up and running. + + If the container is not up and running, it will raise an exception. + """ + response = None + try: + response = predictor.predict(self.schema_builder.sample_input) + return (True, response) + # pylint: disable=broad-except + except Exception as e: + if "422 Client Error: Unprocessable Entity for url" in str(e): + raise LocalModelInvocationException(str(e)) + return (False, response) + + return (True, response) + + +class SageMakerTeiServing: + """SageMakerTeiServing class""" + + def _upload_tei_artifacts( + self, + model_path: str, + sagemaker_session: Session, + s3_model_data_url: str = None, + image: str = None, + env_vars: dict = None, + ): + """Uploads the model artifacts to S3. 
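+
+        The artifacts under ``<model_path>/code`` are uploaded uncompressed as an
+        S3 prefix (``CompressionType`` of ``"None"``), so the endpoint can load
+        them without unpacking a tarball first.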
+ + Args: + model_path: Path to the model + sagemaker_session: SageMaker session + s3_model_data_url: S3 model data URL + image: Image to use + env_vars: Environment variables to set + """ + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) + + code_dir = Path(model_path).joinpath("code") + + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + + logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location) + + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) + + model_data = { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } + } + + return (model_data, _update_env_vars(env_vars)) + + +def _update_env_vars(env_vars: dict) -> dict: + """Placeholder docstring""" + updated_env_vars = {} + updated_env_vars.update(_DEFAULT_ENV_VARS) + if env_vars: + updated_env_vars.update(env_vars) + return updated_env_vars diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py index 866167c2c6..25a995eb48 100644 --- a/src/sagemaker/serve/utils/predictors.py +++ b/src/sagemaker/serve/utils/predictors.py @@ -209,6 +209,49 @@ def delete_predictor(self): self._mode_obj.destroy_server() +class TeiLocalModePredictor(PredictorBase): + """Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes""" + + def __init__( + self, + mode_obj: Type[LocalContainerMode], + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ): + self._mode_obj = mode_obj + self.serializer = serializer + self.deserializer = deserializer + + def predict(self, data): + """Placeholder docstring""" + return [ + self.deserializer.deserialize( + io.BytesIO( + self._mode_obj._invoke_serving( + self.serializer.serialize(data), + self.content_type, + self.deserializer.ACCEPT[0], + ) + ), + self.content_type, + ) + ] + + @property + def content_type(self): + """The MIME type of the data sent to the inference endpoint.""" + return self.serializer.CONTENT_TYPE + + @property + def accept(self): + """The content type(s) that are expected from the inference endpoint.""" + return self.deserializer.ACCEPT + + def delete_predictor(self): + """Shut down and remove the container that you created in LOCAL_CONTAINER mode""" + self._mode_obj.destroy_server() + + class TensorflowServingLocalPredictor(PredictorBase): """Lightweight predictor for local deployment in LOCAL_CONTAINER modes""" diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index 8983a4b5c9..342a88c945 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -58,6 +58,15 @@ str(ModelServer.DJL_SERVING): 4, str(ModelServer.TRITON): 5, str(ModelServer.TGI): 6, + str(ModelServer.TEI): 7, +} + +MLFLOW_MODEL_PATH_CODE = { + MLFLOW_LOCAL_PATH: 1, + MLFLOW_S3_PATH: 2, + MLFLOW_MODEL_PACKAGE_PATH: 3, + MLFLOW_RUN_ID: 4, + MLFLOW_REGISTRY_PATH: 5, } MLFLOW_MODEL_PATH_CODE = { diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index 661093f249..3ac80aa7ea 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ 
-18,6 +18,7 @@ def __str__(self): DJL_SERVING = 4 TRITON = 5 TGI = 6 + TEI = 7 class _DjlEngine(Enum): diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 89db48ffd8..a70ba9eb98 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -47,6 +47,7 @@ _log_sagemaker_config_single_substitution, _log_sagemaker_config_merge, ) +from sagemaker.enums import RoutingStrategy from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string from sagemaker.workflow.entities import PipelineVariable @@ -1696,6 +1697,36 @@ def deep_override_dict( return unflatten_dict(flattened_dict1) if flattened_dict1 else {} +def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Resolve Routing Config + + Args: + routing_config (Optional[Dict[str, Any]]): The routing config. + + Returns: + Optional[Dict[str, Any]]: The resolved routing config. + + Raises: + ValueError: If the RoutingStrategy is invalid. + """ + + if routing_config: + routing_strategy = routing_config.get("RoutingStrategy", None) + if routing_strategy: + if isinstance(routing_strategy, RoutingStrategy): + return {"RoutingStrategy": routing_strategy.name} + if isinstance(routing_strategy, str) and ( + routing_strategy.upper() == RoutingStrategy.RANDOM.name + or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name + ): + return {"RoutingStrategy": routing_strategy.upper()} + raise ValueError( + "RoutingStrategy must be either RoutingStrategy.RANDOM " + "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS" + ) + return None + + @lru_cache def get_instance_rate_per_hour( instance_type: str, diff --git a/tests/conftest.py b/tests/conftest.py index 0309781e7b..7bab05dfb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -253,7 +253,9 @@ def mxnet_eia_latest_py_version(): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_training_py_version(pytorch_training_version, request): - if Version(pytorch_training_version) >= Version("2.0"): + if Version(pytorch_training_version) >= Version("2.3"): + return "py311" + elif Version(pytorch_training_version) >= Version("2.0"): return "py310" elif Version(pytorch_training_version) >= Version("1.13"): return "py39" diff --git a/tests/data/serve_resources/mlflow/pytorch/requirements.txt b/tests/data/serve_resources/mlflow/pytorch/requirements.txt index a10494de69..895e2173bf 100644 --- a/tests/data/serve_resources/mlflow/pytorch/requirements.txt +++ b/tests/data/serve_resources/mlflow/pytorch/requirements.txt @@ -1,4 +1,4 @@ -mlflow==2.10.2 +mlflow==2.12.1 astunparse==1.6.3 cffi==1.16.0 cloudpickle==2.2.1 @@ -10,7 +10,7 @@ opt-einsum==3.3.0 packaging==21.3 pandas==2.2.1 pyyaml==6.0.1 -requests==2.31.0 +requests==2.32.2 torch==2.0.1 torchvision==0.15.2 tqdm==4.66.3 diff --git a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt index 2ff55b8e87..d4ff5b4782 100644 --- a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt +++ b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt @@ -1,4 +1,4 @@ -mlflow==2.11.1 +mlflow==2.12.1 cloudpickle==2.2.1 numpy==1.26.4 tensorflow==2.16.1 diff --git a/tests/data/serve_resources/mlflow/xgboost/requirements.txt b/tests/data/serve_resources/mlflow/xgboost/requirements.txt index 8150c9fedf..18d687aec6 100644 --- a/tests/data/serve_resources/mlflow/xgboost/requirements.txt +++ 
b/tests/data/serve_resources/mlflow/xgboost/requirements.txt @@ -1,4 +1,4 @@ -mlflow==2.11.1 +mlflow==2.12.1 lz4==4.3.2 numpy==1.24.4 pandas==2.0.3 diff --git a/tests/integ/sagemaker/conftest.py b/tests/integ/sagemaker/conftest.py index 2640f51515..043b0c703e 100644 --- a/tests/integ/sagemaker/conftest.py +++ b/tests/integ/sagemaker/conftest.py @@ -278,13 +278,14 @@ def _generate_sagemaker_sdk_tar(destination_folder): """ Run setup.py sdist to generate the PySDK tar file """ - command = f"python3 setup.py egg_info --egg-base {destination_folder} sdist -d {destination_folder} -k" + command = f"python3 setup.py egg_info --egg-base {destination_folder} sdist -d {destination_folder} -k --verbose" print(f"Running command: {command}") result = subprocess.run(command, shell=True, check=True, capture_output=True) if result.returncode != 0: print(f"Command failed with return code: {result.returncode}") - print(f"Standard output: {result.stdout.decode()}") - print(f"Standard error: {result.stderr.decode()}") + + print(f"Standard output: {result.stdout.decode()}") + print(f"Standard error: {result.stderr.decode()}") destination_folder_contents = os.listdir(destination_folder) source_archive = [file for file in destination_folder_contents if file.endswith("tar.gz")][0] diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py index 7835c8ae3c..ad0527fcc0 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py @@ -34,6 +34,14 @@ JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16" ROLE_NAME = "SageMakerRole" +SAMPLE_MMS_PROMPT = [ + "How cute your dog is!", + "Your dog is so cute.", + "The mitochondria is the powerhouse of the cell.", +] +SAMPLE_MMS_RESPONSE = {"embedding": []} +JS_MMS_MODEL_ID = "huggingface-sentencesimilarity-bge-m3" + @pytest.fixture def happy_model_builder(sagemaker_session): @@ -46,6 +54,17 @@ def happy_model_builder(sagemaker_session): ) + +@pytest.fixture +def happy_mms_model_builder(sagemaker_session): + iam_client = sagemaker_session.boto_session.client("iam") + return ModelBuilder( + model=JS_MMS_MODEL_ID, + schema_builder=SchemaBuilder(SAMPLE_MMS_PROMPT, SAMPLE_MMS_RESPONSE), + role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"], + sagemaker_session=sagemaker_session, + ) + + @pytest.mark.skipif( PYTHON_VERSION_IS_NOT_310, reason="The goal of these test are to test the serving components of our feature", ) @@ -75,3 +94,34 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type): ) if caught_ex: raise caught_ex + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these tests is to test the serving components of our feature", +) +@pytest.mark.slow_test +def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + model = happy_mms_model_builder.build() + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False) + logger.info("Endpoint successfully deployed.") + + updated_sample_input = happy_mms_model_builder.schema_builder.sample_input + + predictor.predict(updated_sample_input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=happy_mms_model_builder.sagemaker_session, +
model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + raise caught_ex diff --git a/tests/integ/sagemaker/serve/test_serve_tei.py b/tests/integ/sagemaker/serve/test_serve_tei.py new file mode 100644 index 0000000000..5cf1a3635c --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_tei.py @@ -0,0 +1,87 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder, Mode + +from tests.integ.sagemaker.serve.constants import ( + HF_DIR, + PYTHON_VERSION_IS_NOT_310, + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) + +from tests.integ.timeout import timeout +from tests.integ.utils import cleanup_model_resources +import logging + +logger = logging.getLogger(__name__) + +sample_input = {"inputs": "What is Deep Learning?"} + +loaded_response = [] + + +@pytest.fixture +def model_input(): + return {"inputs": "What is Deep Learning?"} + + +@pytest.fixture +def model_builder_model_schema_builder(): + return ModelBuilder( + model_path=HF_DIR, + model="BAAI/bge-m3", + schema_builder=SchemaBuilder(sample_input, loaded_response), + ) + + +@pytest.fixture +def model_builder(request): + return request.getfixturevalue(request.param) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing feature needs latest metadata", +) +@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True) +def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + model = model_builder.build( + mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session + ) + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1) + predictor.predict(model_input) + assert predictor is not None + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test" diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 64029f7290..33a1ae6708 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -127,4 +127,4 @@ def test_pytorch_transformers_sagemaker_endpoint( logger.exception(caught_ex) assert ( False - ), f"{caught_ex} was thrown when running pytorch transformers sagemaker 
endpoint test" + ), f"{caught_ex} thrown when running pytorch transformers sagemaker endpoint test" diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 5a9662c164..fa10fd24fe 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -18,6 +18,14 @@ from tests.unit.sagemaker.image_uris import expected_uris, conftest LMI_VERSIONS = ["0.24.0"] +TEI_VERSIONS_MAPPING = { + "gpu": { + "1.2.3": "2.0.1-tei1.2.3-gpu-py310-cu122-ubuntu22.04", + }, + "cpu": { + "1.2.3": "2.0.1-tei1.2.3-cpu-py310-ubuntu22.04", + }, +} HF_VERSIONS_MAPPING = { "gpu": { "0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04", @@ -68,6 +76,28 @@ def test_huggingface_uris(load_config): assert expected == uri +@pytest.mark.parametrize( + "load_config", ["huggingface-tei.json", "huggingface-tei-cpu.json"], indirect=True +) +def test_huggingface_tei_uris(load_config): + VERSIONS = load_config["inference"]["versions"] + device = load_config["inference"]["processors"][0] + backend = "huggingface-tei" if device == "gpu" else "huggingface-tei-cpu" + repo = "tei" if device == "gpu" else "tei-cpu" + for version in VERSIONS: + ACCOUNTS = load_config["inference"]["versions"][version]["registries"] + for region in ACCOUNTS.keys(): + uri = get_huggingface_llm_image_uri(backend, region=region, version=version) + expected = expected_uris.huggingface_llm_framework_uri( + repo, + ACCOUNTS[region], + version, + TEI_VERSIONS_MAPPING[device][version], + region=region, + ) + assert expected == uri + + @pytest.mark.parametrize("load_config", ["huggingface-llm.json"], indirect=True) def test_lmi_uris(load_config): VERSIONS = load_config["inference"]["versions"] diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 977853c5be..301afe4d53 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -1,6 +1,7 @@ from __future__ import absolute_import -import json + import datetime +import json from unittest import TestCase from unittest.mock import Mock, patch diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 953cbe775c..69ea2c1f56 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -125,6 +125,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.create_model.assert_called_with( @@ -184,6 +185,7 @@ def test_deploy_accelerator_type( volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -506,6 +508,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -938,6 +941,7 @@ def test_deploy_customized_volume_size_and_timeout( volume_size=volume_size_gb, model_data_download_timeout=model_data_download_timeout_sec, container_startup_health_check_timeout=startup_health_check_timeout_sec, + routing_config=None, ) 
sagemaker_session.create_model.assert_called_with( @@ -987,6 +991,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=name_from_base(MODEL_NAME), diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 4ec96e88e3..e38317067c 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -64,6 +64,10 @@ "123456789712.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi" "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" ) +mock_invalid_image_uri = ( + "123456789712.dkr.ecr.us-west-2.amazonaws.com/invalid" + "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" +) mock_djl_image_uri = ( "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" ) @@ -83,6 +87,88 @@ class TestJumpStartBuilder(unittest.TestCase): + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test__build_for_jumpstart_value_error( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/invalid", + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + mock_pre_trained_model.return_value.image_uri = mock_invalid_image_uri + + self.assertRaises( + ValueError, + lambda: builder.build(), + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_mms_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test__build_for_mms_jumpstart( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_mms, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + mock_pre_trained_model.return_value.image_uri = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface" + "-pytorch-inference:2.1.0-transformers4.37.0-gpu-py310-cu118" + "-ubuntu20.04" + ) + + builder.build() + builder.serve_settings.telemetry_opt_out 
= True + + mock_prepare_for_mms.assert_called() + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 3ffbdd7c03..0c06b5ae8e 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -1476,6 +1476,44 @@ def test_text_generation( mock_build_for_tgi.assert_called_once() + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei") + @patch("sagemaker.image_uris.retrieve") + @patch("sagemaker.djl_inference.model.urllib") + @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.huggingface.llm_utils.urllib") + @patch("sagemaker.huggingface.llm_utils.json") + @patch("sagemaker.model_uris.retrieve") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_sentence_similarity( + self, + mock_serveSettings, + mock_model_uris_retrieve, + mock_llm_utils_json, + mock_llm_utils_urllib, + mock_model_json, + mock_model_urllib, + mock_image_uris_retrieve, + mock_build_for_tei, + ): + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_model_uris_retrieve.side_effect = KeyError + mock_llm_utils_json.load.return_value = {"pipeline_tag": "sentence-similarity"} + mock_llm_utils_urllib.request.Request.side_effect = Mock() + + mock_model_json.load.return_value = {"some": "config"} + mock_model_urllib.request.Request.side_effect = Mock() + mock_build_for_tei.side_effect = Mock() + + mock_image_uris_retrieve.return_value = "https://some-image-uri" + + model_builder = ModelBuilder(model="bloom-560m", schema_builder=schema_builder) + model_builder.build(sagemaker_session=mock_session) + + mock_build_for_tei.assert_called_once() + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") @patch("sagemaker.image_uris.retrieve") diff --git a/tests/unit/sagemaker/serve/builder/test_tei_builder.py b/tests/unit/sagemaker/serve/builder/test_tei_builder.py new file mode 100644 index 0000000000..4a75174bfc --- /dev/null +++ b/tests/unit/sagemaker/serve/builder/test_tei_builder.py @@ -0,0 +1,152 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import MagicMock, patch + +import unittest +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from tests.unit.sagemaker.serve.constants import MOCK_VPC_CONFIG + +from sagemaker.serve.utils.predictors import TeiLocalModePredictor + +mock_model_id = "bert-base-uncased" +mock_prompt = "The man worked as a [MASK]." 
+mock_sample_input = {"inputs": mock_prompt} +mock_sample_output = [ + { + "score": 0.0974755585193634, + "token": 10533, + "token_str": "carpenter", + "sequence": "the man worked as a carpenter.", + }, + { + "score": 0.052383411675691605, + "token": 15610, + "token_str": "waiter", + "sequence": "the man worked as a waiter.", + }, + { + "score": 0.04962712526321411, + "token": 13362, + "token_str": "barber", + "sequence": "the man worked as a barber.", + }, + { + "score": 0.0378861166536808, + "token": 15893, + "token_str": "mechanic", + "sequence": "the man worked as a mechanic.", + }, + { + "score": 0.037680838257074356, + "token": 18968, + "token_str": "salesman", + "sequence": "the man worked as a salesman.", + }, +] +mock_schema_builder = MagicMock() +mock_schema_builder.sample_input = mock_sample_input +mock_schema_builder.sample_output = mock_sample_output +MOCK_IMAGE_CONFIG = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" +) + + +class TestTEIBuilder(unittest.TestCase): + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + def test_build_deploy_for_tei_local_container_and_remote_container( + self, + mock_get_nb_instance, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + vpc_config=MOCK_VPC_CONFIG, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) + + assert model.vpc_config == MOCK_VPC_CONFIG + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TeiLocalModePredictor) + + assert builder.nb_instance_type == "ml.g5.24xlarge" + + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + assert "HF_MODEL_ID" in model.env + + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + def test_image_uri_override( + self, + mock_get_nb_instance, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + image_uri=MOCK_IMAGE_CONFIG, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) + + assert builder.image_uri == MOCK_IMAGE_CONFIG + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TeiLocalModePredictor) + + assert builder.nb_instance_type == "ml.g5.24xlarge" + + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + predictor = 
model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + assert "HF_MODEL_ID" in model.env + + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py index b7e3db79d6..9ea797adc2 100644 --- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py @@ -110,7 +110,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container( return_value="ml.g5.24xlarge", ) @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) - def test_image_uri( + def test_image_uri_override( self, mock_get_nb_instance, mock_telemetry, @@ -144,3 +144,29 @@ def test_image_uri( with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata", + return_value=None, + ) + def test_failure_hf_md( + self, mock_model_md, mock_get_nb_instance, mock_telemetry, mock_build_for_transformers + ): + builder = ModelBuilder( + model=mock_model_id, + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + builder.build() + + mock_build_for_transformers.assert_called_once() diff --git a/tests/unit/sagemaker/serve/model_server/tei/__init__.py b/tests/unit/sagemaker/serve/model_server/tei/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py new file mode 100644 index 0000000000..16dcf12b5a --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py @@ -0,0 +1,150 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from pathlib import PosixPath +from unittest import TestCase +from unittest.mock import Mock, patch + +from docker.types import DeviceRequest +from sagemaker.serve.model_server.tei.server import LocalTeiServing, SageMakerTeiServing +from sagemaker.serve.utils.exceptions import LocalModelInvocationException + +TEI_IMAGE = ( + "246618743249.dkr.ecr.us-west-2.amazonaws.com/tei:2.0.1-tei1.2.3-gpu-py310-cu122-ubuntu22.04" +) +MODEL_PATH = "model_path" +ENV_VAR = {"KEY": "VALUE"} +PAYLOAD = { + "inputs": { + "sourceSentence": "How cute your dog is!", + "sentences": ["The mitochondria is the powerhouse of the cell.", "Your dog is so cute."], + } +} +S3_URI = "s3://mock_model_data_uri" +SECRET_KEY = "secret_key" +INFER_RESPONSE = [] + + +class TeiServerTests(TestCase): + @patch("sagemaker.serve.model_server.tei.server.requests") + def test_start_invoke_destroy_local_tei_server(self, mock_requests): + mock_container = Mock() + mock_docker_client = Mock() + mock_docker_client.containers.run.return_value = mock_container + + local_tei_server = LocalTeiServing() + mock_schema_builder = Mock() + mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD + local_tei_server.schema_builder = mock_schema_builder + + local_tei_server._start_tei_serving( + client=mock_docker_client, + model_path=MODEL_PATH, + secret_key=SECRET_KEY, + image=TEI_IMAGE, + env_vars=ENV_VAR, + ) + + mock_docker_client.containers.run.assert_called_once_with( + TEI_IMAGE, + shm_size="2G", + device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])], + network_mode="host", + detach=True, + auto_remove=True, + volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}}, + environment={ + "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", + "KEY": "VALUE", + "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", + }, + ) + + mock_response = Mock() + mock_requests.post.side_effect = lambda *args, **kwargs: mock_response + mock_response.content = INFER_RESPONSE + + res = local_tei_server._invoke_tei_serving( + request=PAYLOAD, content_type="application/json", accept="application/json" + ) + + self.assertEqual(res, INFER_RESPONSE) + + def test_tei_deep_ping(self): + mock_predictor = Mock() + mock_response = Mock() + mock_schema_builder = Mock() + + mock_predictor.predict.side_effect = lambda *args, **kwargs: mock_response + mock_schema_builder.sample_input = PAYLOAD + + local_tei_server = LocalTeiServing() + local_tei_server.schema_builder = mock_schema_builder + res = local_tei_server._tei_deep_ping(mock_predictor) + + self.assertEqual(res, (True, mock_response)) + + def test_tei_deep_ping_invoke_ex(self): + mock_predictor = Mock() + mock_schema_builder = Mock() + + mock_predictor.predict.side_effect = lambda *args, **kwargs: exec( + 'raise(ValueError("422 Client Error: Unprocessable Entity for url:"))' + ) + mock_schema_builder.sample_input = PAYLOAD + + local_tei_server = LocalTeiServing() + local_tei_server.schema_builder = mock_schema_builder + + self.assertRaises( + LocalModelInvocationException, lambda: local_tei_server._tei_deep_ping(mock_predictor) + ) + + def test_tei_deep_ping_ex(self): + mock_predictor = Mock() + + mock_predictor.predict.side_effect = lambda *args, **kwargs: Exception() + + local_tei_server = LocalTeiServing() + res = local_tei_server._tei_deep_ping(mock_predictor) + + self.assertEqual(res, (False, None)) + + @patch("sagemaker.serve.model_server.tei.server.S3Uploader") + def 
test_upload_artifacts_sagemaker_tei_server(self, mock_uploader): + mock_session = Mock() + mock_uploader.upload.side_effect = ( + lambda *args, **kwargs: "s3://sagemaker-us-west-2-123456789123/tei-2024-05-20-16-05-36-027/code" + ) + + s3_upload_path, env_vars = SageMakerTeiServing()._upload_tei_artifacts( + model_path=MODEL_PATH, + sagemaker_session=mock_session, + s3_model_data_url=S3_URI, + image=TEI_IMAGE, + ) + + mock_uploader.upload.assert_called_once() + self.assertEqual( + s3_upload_path, + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-us-west-2-123456789123/tei-2024-05-20-16-05-36-027/code/", + } + }, + ) + self.assertIsNotNone(env_vars) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index e955d68227..97d4e6ec2a 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -854,17 +854,14 @@ def test_validate_smdataparallel_args_raises(): # Cases {PT|TF2} # 1. None instance type - # 2. incorrect instance type - # 3. incorrect python version - # 4. incorrect framework version + # 2. incorrect python version + # 3. incorrect framework version bad_args = [ (None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled), - ("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled), (None, "pytorch", "1.6.0", "py3", smdataparallel_enabled), - ("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled), ] @@ -966,74 +963,6 @@ def test_validate_smdataparallel_args_not_raises(): ) -def test_validate_pytorchddp_not_raises(): - # Case 1: Framework is not PyTorch - fw_utils.validate_pytorch_distribution( - distribution=None, - framework_name="tensorflow", - framework_version="2.9.1", - py_version="py3", - image_uri="custom-container", - ) - # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP - pytorchddp_disabled = {"pytorchddp": {"enabled": False}} - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_disabled, - framework_name="pytorch", - framework_version="1.10", - py_version="py3", - image_uri="custom-container", - ) - # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions - pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - pytorchddp_supported_fw_versions = [ - "1.10", - "1.10.0", - "1.10.2", - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - "1.13.1", - "2.0.0", - "2.0.1", - "2.1.0", - "2.2.0", - ] - for framework_version in pytorchddp_supported_fw_versions: - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version=framework_version, - py_version="py3", - image_uri="custom-container", - ) - - -def test_validate_pytorchddp_raises(): - pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - # Case 1: Unsupported framework version - with pytest.raises(ValueError): - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version="1.8", - py_version="py3", - image_uri=None, - ) - - # Case 2: Unsupported Py version - with pytest.raises(ValueError): - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - 
framework_version="1.10", - py_version="py2", - image_uri=None, - ) - - def test_validate_torch_distributed_not_raises(): # Case 1: Framework is PyTorch, but torch_distributed is not enabled torch_distributed_disabled = {"torch_distributed": {"enabled": False}} diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5ada026ef8..618d0d7ea8 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -801,14 +801,15 @@ def test_pytorch_ddp_distribution_configuration( distribution=pytorch.distribution ) expected_torch_ddp = { - "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_distributed_dataparallel_enabled": True, + "sagemaker_distributed_dataparallel_custom_mpi_options": "", "sagemaker_instance_type": test_instance_type, } assert actual_pytorch_ddp == expected_torch_ddp def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): - unsupported_framework_version = "1.9.1" + unsupported_framework_version = "1.5.0" unsupported_py_version = "py2" with pytest.raises(ValueError) as error: _pytorch_estimator( diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 93abcfc7a8..d5214d01c3 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,6 +30,7 @@ from mock import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.enums import RoutingStrategy from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings from sagemaker.utils import ( @@ -52,6 +53,7 @@ can_model_package_source_uri_autopopulate, get_instance_rate_per_hour, extract_instance_rate_per_hour, + _resolve_routing_config, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -2014,3 +2016,30 @@ def test_extract_instance_rate_per_hour(price_data, expected_result): out = extract_instance_rate_per_hour(price_data) assert out == expected_result + + +@pytest.mark.parametrize( + "routing_config, expected", + [ + ({"RoutingStrategy": RoutingStrategy.RANDOM}, {"RoutingStrategy": "RANDOM"}), + ({"RoutingStrategy": "RANDOM"}, {"RoutingStrategy": "RANDOM"}), + ( + {"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS}, + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + ), + ( + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + ), + ({"RoutingStrategy": None}, None), + (None, None), + ], +) +def test_resolve_routing_config(routing_config, expected): + res = _resolve_routing_config(routing_config) + + assert res == expected + + +def test_resolve_routing_config_ex(): + pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"})) diff --git a/tox.ini b/tox.ini index 718e968013..6e1f9ce956 100644 --- a/tox.ini +++ b/tox.ini @@ -81,7 +81,7 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" - pip install 'apache-airflow==2.9.0' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.8.txt" + pip install 'apache-airflow==2.9.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.1/constraints-3.8.txt" pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' pip install 'dill>=0.3.8'
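Reviewer note on the routing changes above: the behavior of _resolve_routing_config is taken directly from the src/sagemaker/utils.py and tests/unit/test_utils.py hunks, while the deploy() call below is inferred from the routing_config=None argument that tests/unit/sagemaker/model/test_deploy.py now expects sagemaker.production_variant to receive. This is a minimal sketch, not part of the patch; the model object, instance type, and instance count are placeholder assumptions.

from sagemaker.enums import RoutingStrategy
from sagemaker.utils import _resolve_routing_config

# Enum members and their string names (any case) normalize to the same
# {"RoutingStrategy": "<NAME>"} dict passed through to the production variant.
assert _resolve_routing_config({"RoutingStrategy": RoutingStrategy.RANDOM}) == {
    "RoutingStrategy": "RANDOM"
}
assert _resolve_routing_config({"RoutingStrategy": "least_outstanding_requests"}) == {
    "RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"
}

# A missing or None strategy resolves to None; any other value raises ValueError.
assert _resolve_routing_config(None) is None

# Hypothetical deployment using the new knob, where `model` is an existing
# sagemaker.model.Model (endpoint details are illustrative only):
# predictor = model.deploy(
#     initial_instance_count=1,
#     instance_type="ml.g5.2xlarge",
#     routing_config={"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS},
# )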