# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import botocore
import datetime
import time
import functools
from pydantic import validate_call
from typing import Dict, List, Literal, Optional, Union, Any
from boto3.session import Session
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.status import Status
from rich.style import Style
from sagemaker_core.main.code_injection.codec import transform
from sagemaker_core.main.code_injection.constants import Color
from sagemaker_core.main.utils import (
SageMakerClient,
ResourceIterator,
Unassigned,
get_textual_rich_logger,
snake_to_pascal,
pascal_to_snake,
is_not_primitive,
is_not_str_dict,
is_primitive_list,
serialize,
)
from sagemaker_core.main.intelligent_defaults_helper import (
load_default_configs_for_resource_name,
get_config_value,
)
from sagemaker_core.main.logs import MultiLogStreamHandler
from sagemaker_core.main.shapes import *
from sagemaker_core.main.exceptions import *
# Module-level logger, obtained from the project's get_textual_rich_logger helper.
logger = get_textual_rich_logger(__name__)
# [docs] -- documentation-link marker left over from extraction; not part of the code
class Base(BaseModel):
    """
    Shared base class for the resource classes in this module.

    Bundles the helpers the resources rely on: building a service client,
    filling missing ``create`` inputs from default configs, replacing chained
    resource objects with their names, and a pydantic argument-validation
    decorator.
    """

    model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid")

    @classmethod
    def get_sagemaker_client(cls, session=None, region_name=None, service_name="sagemaker"):
        """
        Build a service client through ``SageMakerClient``.

        Parameters:
            session: Session forwarded to ``SageMakerClient``.
            region_name: Region name forwarded to ``SageMakerClient``.
            service_name: Name of the service client to fetch. Defaults to "sagemaker".

        Returns:
            The client returned by ``SageMakerClient.get_client``.
        """
        return SageMakerClient(session=session, region_name=region_name).get_client(
            service_name=service_name
        )

    @staticmethod
    def get_updated_kwargs_with_configured_attributes(
        config_schema_for_resource: dict, resource_name: str, **kwargs
    ):
        """
        Fill missing keyword arguments from the default configs.

        For every attribute named in ``config_schema_for_resource`` whose value in
        ``kwargs`` is absent or None, the resource-level and global default configs
        are consulted. When a value is found, it is used to construct the shape
        class whose name is the PascalCase form of the attribute.

        Parameters:
            config_schema_for_resource: Mapping whose keys are the configurable attribute names.
            resource_name: Name of the resource whose default configs are loaded.
            **kwargs: Keyword arguments intended for the resource's create call.

        Returns:
            ``kwargs`` with any resolved defaults added. If loading or applying the
            defaults fails, the failure is logged at debug level and the kwargs
            gathered so far are returned.
        """
        try:
            for configurable_attribute in config_schema_for_resource:
                if kwargs.get(configurable_attribute) is not None:
                    # Caller supplied a value; defaults never override it.
                    continue
                resource_defaults = load_default_configs_for_resource_name(
                    resource_name=resource_name
                )
                global_defaults = load_default_configs_for_resource_name(
                    resource_name="GlobalDefaults"
                )
                if config_value := get_config_value(
                    configurable_attribute, resource_defaults, global_defaults
                ):
                    # Use a separate name here: overwriting ``resource_name`` would make
                    # the next iteration load defaults for the wrong resource.
                    shape_name = snake_to_pascal(configurable_attribute)
                    class_object = globals()[shape_name]
                    kwargs[configurable_attribute] = class_object(**config_value)
        except Exception:
            # Best effort only. ``Exception`` (not ``BaseException``) so that
            # KeyboardInterrupt / SystemExit still propagate.
            logger.debug("Could not load Default Configs. Continuing.", exc_info=True)
            # Continue with existing kwargs if no default configs found
        return kwargs

    @staticmethod
    def populate_chained_attributes(resource_name: str, operation_input_args: Union[dict, object]):
        """
        Prepare operation input arguments that may reference other resources.

        Entries equal to ``Unassigned()`` are removed. For ``*name`` arguments that
        do not name the resource itself, a non-string value is replaced by the
        result of its ``get_name()``. Non-empty lists of non-primitive items are
        rebuilt item by item through ``_get_chained_attribute``.

        Parameters:
            resource_name: PascalCase name of the resource the arguments belong to.
            operation_input_args: The arguments. A mapping-like object (anything
                exposing ``keys()``) is used directly; any other object is processed
                through its attribute dict (``vars``).

        Returns:
            The processed arguments. This is the same mapping that was passed in
            (or the object's attribute dict), modified in place.
        """
        resource_name_in_snake_case = pascal_to_snake(resource_name)
        updated_args = (
            operation_input_args
            if hasattr(operation_input_args, "keys")
            else vars(operation_input_args)
        )
        unassigned_args = []
        # Snapshot the keys: values are reassigned while looping.
        for arg in list(updated_args.keys()):
            value = updated_args.get(arg)
            arg_snake = pascal_to_snake(arg)
            if value == Unassigned():
                unassigned_args.append(arg)
            elif not value:
                # None, empty containers/strings, False and 0 are left untouched.
                continue
            elif (
                arg_snake.endswith("name")
                and arg_snake[: -len("_name")] != resource_name_in_snake_case
                and arg_snake != "name"
            ):
                # A reference to another resource: accept either its name or the
                # resource object itself (value is known to be truthy and assigned here).
                if not isinstance(value, str):
                    updated_args[arg] = value.get_name()
            elif isinstance(value, list):
                if not is_primitive_list(value):
                    updated_args[arg] = [
                        Base._get_chained_attribute(list_item) for list_item in value
                    ]
            elif is_not_primitive(value) and is_not_str_dict(value) and type(value) == object:
                # NOTE(review): ``type(value) == object`` only matches bare ``object()``
                # instances, so this branch looks unreachable for shape objects. Left
                # unchanged because enabling it would start rebuilding every nested
                # shape -- confirm the intended behaviour.
                updated_args[arg] = Base._get_chained_attribute(item_value=value)

        for unassigned_arg in unassigned_args:
            del updated_args[unassigned_arg]
        return updated_args

    @staticmethod
    def _get_chained_attribute(item_value: Any):
        """
        Rebuild ``item_value`` as a new instance of its own class with chained attributes resolved.

        Parameters:
            item_value: Object whose class name is defined in this module's globals.

        Returns:
            A new instance of the same class, constructed from the processed attributes.
        """
        resource_name = type(item_value).__name__
        class_object = globals()[resource_name]
        # Work on a copy of the attribute dict: populate_chained_attributes deletes
        # Unassigned entries in place, which would otherwise strip attributes off
        # the caller's object.
        return class_object(
            **Base.populate_chained_attributes(
                resource_name=resource_name, operation_input_args=dict(vars(item_value))
            )
        )

    @staticmethod
    def add_validate_call(func):
        """
        Decorator that validates call arguments with pydantic's ``validate_call``.

        The validating wrapper is built lazily on the first call (not at decoration
        time) and then reused, instead of being rebuilt on every invocation.

        Parameters:
            func: The function to wrap.

        Returns:
            A wrapper with the same metadata as ``func``.
        """
        validated_func = None

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal validated_func
            if validated_func is None:
                config = dict(arbitrary_types_allowed=True)
                validated_func = validate_call(config=config)(func)
            return validated_func(*args, **kwargs)

        return wrapper
# [docs] -- documentation-link marker left over from extraction; not part of the code
class Action(Base):
    """
    Class representing resource Action

    Attributes:
        action_name: The name of the action.
        action_arn: The Amazon Resource Name (ARN) of the action.
        source: The source of the action.
        action_type: The type of the action.
        description: The description of the action.
        status: The status of the action.
        properties: A list of the action's properties.
        creation_time: When the action was created.
        created_by:
        last_modified_time: When the action was last modified.
        last_modified_by:
        metadata_properties:
        lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group.
    """

    action_name: str
    action_arn: Optional[str] = Unassigned()
    source: Optional[ActionSource] = Unassigned()
    action_type: Optional[str] = Unassigned()
    description: Optional[str] = Unassigned()
    status: Optional[str] = Unassigned()
    properties: Optional[Dict[str, str]] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    created_by: Optional[UserContext] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    last_modified_by: Optional[UserContext] = Unassigned()
    metadata_properties: Optional[MetadataProperties] = Unassigned()
    lineage_group_arn: Optional[str] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of this resource's name attribute.

        Returns:
            The value of the first instance attribute called ``action_name`` or
            ``name``, or None (after logging an error) when neither exists.
        """
        name_parts = "action_name".split("_")
        # Every "_"-joined suffix of the resource name is an acceptable attribute name.
        candidates = {"_".join(name_parts[i:]) for i in range(len(name_parts))}
        candidates.add("name")

        for attribute, value in vars(self).items():
            if attribute in candidates:
                return value
        logger.error("Name attribute not found for object action")
        return None

    @classmethod
    @Base.add_validate_call
    def create(
        cls,
        action_name: str,
        source: ActionSource,
        action_type: str,
        description: Optional[str] = Unassigned(),
        status: Optional[str] = Unassigned(),
        properties: Optional[Dict[str, str]] = Unassigned(),
        metadata_properties: Optional[MetadataProperties] = Unassigned(),
        tags: Optional[List[Tag]] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Action"]:
        """
        Create an Action resource

        Parameters:
            action_name: The name of the action. Must be unique to your account in an Amazon Web Services Region.
            source: The source type, ID, and URI.
            action_type: The action type.
            description: The description of the action.
            status: The status of the action.
            properties: A list of properties to add to the action.
            metadata_properties:
            tags: A list of tags to apply to the action.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Action resource, fetched with ``get`` after the create call.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceLimitExceeded: You have exceeded a SageMaker resource limit. For example, you might have too many training jobs created.
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """

        logger.info("Creating action resource.")
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "ActionName": action_name,
            "Source": source,
            "ActionType": action_type,
            "Description": description,
            "Status": status,
            "Properties": properties,
            "MetadataProperties": metadata_properties,
            "Tags": tags,
        }

        # drop Unassigned entries and resolve chained resource references
        operation_input_args = Base.populate_chained_attributes(
            resource_name="Action", operation_input_args=operation_input_args
        )

        logger.debug(f"Input request: {operation_input_args}")
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # create the resource
        response = client.create_action(**operation_input_args)
        logger.debug(f"Response: {response}")

        return cls.get(action_name=action_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        action_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Action"]:
        """
        Get an Action resource

        Parameters:
            action_name: The name of the action to describe.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Action resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """

        operation_input_args = {
            "ActionName": action_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = client.describe_action(**operation_input_args)

        logger.debug(response)

        # deserialize the response
        transformed_response = transform(response, "DescribeActionResponse")
        action = cls(**transformed_response)
        return action

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["Action"]:
        """
        Refresh an Action resource

        The describe call uses a client built with the default session and region.

        Returns:
            The Action resource (this same instance, updated in place).

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """

        operation_input_args = {
            "ActionName": self.action_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client()
        response = client.describe_action(**operation_input_args)

        # deserialize response and update self
        transform(response, "DescribeActionResponse", self)
        return self

    @Base.add_validate_call
    def update(
        self,
        description: Optional[str] = Unassigned(),
        status: Optional[str] = Unassigned(),
        properties: Optional[Dict[str, str]] = Unassigned(),
        properties_to_remove: Optional[List[str]] = Unassigned(),
    ) -> Optional["Action"]:
        """
        Update an Action resource

        The update call uses a client built with the default session and region.

        Parameters:
            description: The description of the action.
            status: The status of the action.
            properties: Properties to send with the update request.
            properties_to_remove: A list of properties to remove.

        Returns:
            The Action resource (this same instance, refreshed after the update).

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact.
            ResourceNotFound: Resource being accessed is not found.
        """

        logger.info("Updating action resource.")
        client = Base.get_sagemaker_client()

        operation_input_args = {
            "ActionName": self.action_name,
            "Description": description,
            "Status": status,
            "Properties": properties,
            "PropertiesToRemove": properties_to_remove,
        }
        logger.debug(f"Input request: {operation_input_args}")
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # update the resource
        response = client.update_action(**operation_input_args)
        logger.debug(f"Response: {response}")
        # pull the updated state back into this instance
        self.refresh()

        return self

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete an Action resource

        The delete call uses a client built with the default session and region.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """

        client = Base.get_sagemaker_client()

        operation_input_args = {
            "ActionName": self.action_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client.delete_action(**operation_input_args)

        # logged only once the delete request has been accepted
        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        source_uri: Optional[str] = Unassigned(),
        action_type: Optional[str] = Unassigned(),
        created_after: Optional[datetime.datetime] = Unassigned(),
        created_before: Optional[datetime.datetime] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["Action"]:
        """
        Get all Action resources

        Parameters:
            source_uri: A filter that returns only actions with the specified source URI.
            action_type: A filter that returns only actions of the specified type.
            created_after: A filter that returns only actions created on or after the specified time.
            created_before: A filter that returns only actions created on or before the specified time.
            sort_by: The property used to sort results. The default value is CreationTime.
            sort_order: The sort order. The default value is Descending.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed Action resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "SourceUri": source_uri,
            "ActionType": action_type,
            "CreatedAfter": created_after,
            "CreatedBefore": created_before,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }

        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_actions",
            summaries_key="ActionSummaries",
            summary_name="ActionSummary",
            resource_cls=Action,
            list_method_kwargs=operation_input_args,
        )
# [docs] -- documentation-link marker left over from extraction; not part of the code
class Algorithm(Base):
    """
    Class representing resource Algorithm

    Attributes:
        algorithm_name: The name of the algorithm being described.
        algorithm_arn: The Amazon Resource Name (ARN) of the algorithm.
        creation_time: A timestamp specifying when the algorithm was created.
        training_specification: Details about training jobs run by this algorithm.
        algorithm_status: The current status of the algorithm.
        algorithm_status_details: Details about the current status of the algorithm.
        algorithm_description: A brief summary about the algorithm.
        inference_specification: Details about inference jobs that the algorithm runs.
        validation_specification: Details about configurations for one or more training jobs that SageMaker runs to test the algorithm.
        product_id: The product identifier of the algorithm.
        certify_for_marketplace: Whether the algorithm is certified to be listed in Amazon Web Services Marketplace.
    """

    algorithm_name: str
    algorithm_arn: Optional[str] = Unassigned()
    algorithm_description: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    training_specification: Optional[TrainingSpecification] = Unassigned()
    inference_specification: Optional[InferenceSpecification] = Unassigned()
    validation_specification: Optional[AlgorithmValidationSpecification] = Unassigned()
    algorithm_status: Optional[str] = Unassigned()
    algorithm_status_details: Optional[AlgorithmStatusDetails] = Unassigned()
    product_id: Optional[str] = Unassigned()
    certify_for_marketplace: Optional[bool] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of this resource's name attribute.

        Returns:
            The value of the first instance attribute called ``algorithm_name`` or
            ``name``, or None (after logging an error) when neither exists.
        """
        name_parts = "algorithm_name".split("_")
        # Every "_"-joined suffix of the resource name is an acceptable attribute name.
        candidates = {"_".join(name_parts[i:]) for i in range(len(name_parts))}
        candidates.add("name")

        for attribute, value in vars(self).items():
            if attribute in candidates:
                return value
        logger.error("Name attribute not found for object algorithm")
        return None

    def populate_inputs_decorator(create_func):
        """
        Decorator for ``create`` that fills missing keyword arguments from default configs.

        Only keyword arguments are passed through
        ``Base.get_updated_kwargs_with_configured_attributes``; positional
        arguments are forwarded unchanged.
        """

        @functools.wraps(create_func)
        def wrapper(*args, **kwargs):
            # Attributes of this resource that may be supplied by default configs.
            config_schema_for_resource = {
                "training_specification": {
                    "additional_s3_data_source": {
                        "s3_data_type": {"type": "string"},
                        "s3_uri": {"type": "string"},
                    }
                },
                "validation_specification": {"validation_role": {"type": "string"}},
            }
            return create_func(
                *args,
                **Base.get_updated_kwargs_with_configured_attributes(
                    config_schema_for_resource, "Algorithm", **kwargs
                ),
            )

        return wrapper

    @classmethod
    @populate_inputs_decorator
    @Base.add_validate_call
    def create(
        cls,
        algorithm_name: str,
        training_specification: TrainingSpecification,
        algorithm_description: Optional[str] = Unassigned(),
        inference_specification: Optional[InferenceSpecification] = Unassigned(),
        validation_specification: Optional[AlgorithmValidationSpecification] = Unassigned(),
        certify_for_marketplace: Optional[bool] = Unassigned(),
        tags: Optional[List[Tag]] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Algorithm"]:
        """
        Create an Algorithm resource

        Parameters:
            algorithm_name: The name of the algorithm.
            training_specification: Specifies details about training jobs run by this algorithm, including the following: The Amazon ECR path of the container and the version digest of the algorithm. The hyperparameters that the algorithm supports. The instance types that the algorithm supports for training. Whether the algorithm supports distributed training. The metrics that the algorithm emits to Amazon CloudWatch. Which metrics that the algorithm emits can be used as the objective metric for hyperparameter tuning jobs. The input channels that the algorithm supports for training data. For example, an algorithm might support train, validation, and test channels.
            algorithm_description: A description of the algorithm.
            inference_specification: Specifies details about inference jobs that the algorithm runs, including the following: The Amazon ECR paths of containers that contain the inference code and model artifacts. The instance types that the algorithm supports for transform jobs and real-time endpoints used for inference. The input and output content formats that the algorithm supports for inference.
            validation_specification: Specifies configurations for one or more training jobs and that SageMaker runs to test the algorithm's training code and, optionally, one or more batch transform jobs that SageMaker runs to test the algorithm's inference code.
            certify_for_marketplace: Whether to certify the algorithm so that it can be listed in Amazon Web Services Marketplace.
            tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Algorithm resource, fetched with ``get`` after the create call.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """

        logger.info("Creating algorithm resource.")
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "AlgorithmName": algorithm_name,
            "AlgorithmDescription": algorithm_description,
            "TrainingSpecification": training_specification,
            "InferenceSpecification": inference_specification,
            "ValidationSpecification": validation_specification,
            "CertifyForMarketplace": certify_for_marketplace,
            "Tags": tags,
        }

        # drop Unassigned entries and resolve chained resource references
        operation_input_args = Base.populate_chained_attributes(
            resource_name="Algorithm", operation_input_args=operation_input_args
        )

        logger.debug(f"Input request: {operation_input_args}")
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # create the resource
        response = client.create_algorithm(**operation_input_args)
        logger.debug(f"Response: {response}")

        return cls.get(algorithm_name=algorithm_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        algorithm_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Algorithm"]:
        """
        Get an Algorithm resource

        Parameters:
            algorithm_name: The name of the algorithm to describe.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Algorithm resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """

        operation_input_args = {
            "AlgorithmName": algorithm_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = client.describe_algorithm(**operation_input_args)

        logger.debug(response)

        # deserialize the response
        transformed_response = transform(response, "DescribeAlgorithmOutput")
        algorithm = cls(**transformed_response)
        return algorithm

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["Algorithm"]:
        """
        Refresh an Algorithm resource

        The describe call uses a client built with the default session and region.

        Returns:
            The Algorithm resource (this same instance, updated in place).

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """

        operation_input_args = {
            "AlgorithmName": self.algorithm_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client()
        response = client.describe_algorithm(**operation_input_args)

        # deserialize response and update self
        transform(response, "DescribeAlgorithmOutput", self)
        return self

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete an Algorithm resource

        The delete call uses a client built with the default session and region.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact.
        """

        client = Base.get_sagemaker_client()

        operation_input_args = {
            "AlgorithmName": self.algorithm_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client.delete_algorithm(**operation_input_args)

        # logged only once the delete request has been accepted
        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @Base.add_validate_call
    def wait_for_status(
        self,
        target_status: Literal["Pending", "InProgress", "Completed", "Failed", "Deleting"],
        poll: int = 5,
        timeout: Optional[int] = None,
    ) -> None:
        """
        Wait for an Algorithm resource to reach certain status.

        Parameters:
            target_status: The status to wait for.
            poll: The number of seconds to wait between each poll.
            timeout: The maximum number of seconds to wait before timing out.

        Raises:
            TimeoutExceededError: If the resource does not reach the target status before the timeout.
            FailedStatusError: If the resource reaches a failed state that is not the target status.
            WaiterError: Raised when an error occurs while waiting.
        """
        start_time = time.time()

        progress = Progress(
            SpinnerColumn("bouncingBar"),
            TextColumn("{task.description}"),
            TimeElapsedColumn(),
        )
        progress.add_task(f"Waiting for Algorithm to reach [bold]{target_status} status...")
        status = Status("Current status:")

        with Live(
            Panel(
                Group(progress, status),
                title="Wait Log Panel",
                border_style=Style(color=Color.BLUE.value),
            ),
            transient=True,
        ):
            while True:
                self.refresh()
                current_status = self.algorithm_status
                status.update(f"Current status: [bold]{current_status}")

                # The target check runs first so that waiting for "Failed" succeeds.
                if target_status == current_status:
                    logger.info(f"Final Resource Status: [bold]{current_status}")
                    return

                if "failed" in current_status.lower():
                    raise FailedStatusError(
                        resource_type="Algorithm", status=current_status, reason="(Unknown)"
                    )

                if timeout is not None and time.time() - start_time >= timeout:
                    # Keyword spelled ``resource_type``, matching FailedStatusError above.
                    raise TimeoutExceededError(resource_type="Algorithm", status=current_status)
                time.sleep(poll)

    @Base.add_validate_call
    def wait_for_delete(
        self,
        poll: int = 5,
        timeout: Optional[int] = None,
    ) -> None:
        """
        Wait for an Algorithm resource to be deleted.

        Polls ``refresh`` until the describe call fails with an error code containing
        "ResourceNotFound" or "ValidationException", which is treated as deleted.

        Parameters:
            poll: The number of seconds to wait between each poll.
            timeout: The maximum number of seconds to wait before timing out.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            TimeoutExceededError: If the resource still exists when the timeout elapses.
            WaiterError: Raised when an error occurs while waiting.
        """
        start_time = time.time()

        progress = Progress(
            SpinnerColumn("bouncingBar"),
            TextColumn("{task.description}"),
            TimeElapsedColumn(),
        )
        progress.add_task("Waiting for Algorithm to be deleted...")
        status = Status("Current status:")

        with Live(
            Panel(
                Group(progress, status),
                title="Wait Log Panel",
                border_style=Style(color=Color.BLUE.value),
            )
        ):
            while True:
                try:
                    self.refresh()
                    current_status = self.algorithm_status
                    status.update(f"Current status: [bold]{current_status}")

                    if timeout is not None and time.time() - start_time >= timeout:
                        # Keyword spelled ``resource_type`` (was misspelled ``resouce_type``).
                        raise TimeoutExceededError(resource_type="Algorithm", status=current_status)
                except botocore.exceptions.ClientError as e:
                    error_code = e.response["Error"]["Code"]

                    if "ResourceNotFound" in error_code or "ValidationException" in error_code:
                        logger.info("Resource was not found. It may have been deleted.")
                        return
                    raise e
                time.sleep(poll)

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        creation_time_after: Optional[datetime.datetime] = Unassigned(),
        creation_time_before: Optional[datetime.datetime] = Unassigned(),
        name_contains: Optional[str] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["Algorithm"]:
        """
        Get all Algorithm resources

        Parameters:
            creation_time_after: A filter that returns only algorithms created after the specified time (timestamp).
            creation_time_before: A filter that returns only algorithms created before the specified time (timestamp).
            name_contains: A string in the algorithm name. This filter returns only algorithms whose name contains the specified string.
            sort_by: The parameter by which to sort the results. The default is CreationTime.
            sort_order: The sort order for the results. The default is Ascending.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed Algorithm resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "CreationTimeAfter": creation_time_after,
            "CreationTimeBefore": creation_time_before,
            "NameContains": name_contains,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }

        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_algorithms",
            summaries_key="AlgorithmSummaryList",
            summary_name="AlgorithmSummary",
            resource_cls=Algorithm,
            list_method_kwargs=operation_input_args,
        )
# [docs] -- documentation-link marker left over from extraction; not part of the code
class App(Base):
"""
Class representing resource App
The fields declared without a default (domain_id, app_type, app_name) are required;
every other field defaults to Unassigned().
Attributes:
app_arn: The Amazon Resource Name (ARN) of the app.
app_type: The type of app.
app_name: The name of the app.
domain_id: The domain ID.
user_profile_name: The user profile name.
space_name: The name of the space. If this value is not set, then UserProfileName must be set.
status: The status.
last_health_check_timestamp: The timestamp of the last health check.
last_user_activity_timestamp: The timestamp of the last user's activity. LastUserActivityTimestamp is also updated when SageMaker performs health checks without user activity. As a result, this value is set to the same value as LastHealthCheckTimestamp.
creation_time: The creation time of the application. After an application has been shut down for 24 hours, SageMaker deletes all metadata for the application. To be considered an update and retain application metadata, applications must be restarted within 24 hours after the previous application has been shut down. After this time window, creation of an application is considered a new application rather than an update of the previous application.
failure_reason: The failure reason.
resource_spec: The instance type and the Amazon Resource Name (ARN) of the SageMaker image created on the instance.
built_in_lifecycle_config_arn: The lifecycle configuration that runs before the default lifecycle configuration.
"""
# Required identifiers of the app.
domain_id: str
app_type: str
app_name: str
# Optional fields; Unassigned() marks a value that has not been provided.
app_arn: Optional[str] = Unassigned()
user_profile_name: Optional[str] = Unassigned()
space_name: Optional[str] = Unassigned()
status: Optional[str] = Unassigned()
last_health_check_timestamp: Optional[datetime.datetime] = Unassigned()
last_user_activity_timestamp: Optional[datetime.datetime] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
failure_reason: Optional[str] = Unassigned()
resource_spec: Optional[ResourceSpec] = Unassigned()
built_in_lifecycle_config_arn: Optional[str] = Unassigned()
def get_name(self) -> str:
    """
    Look up the value of this resource's name attribute.
    The instance attributes are scanned in order; the first one that is called
    "name" or that equals a trailing underscore-delimited suffix of "app_name"
    ("app_name", "name") wins.
    Returns:
        The value of the matching attribute. When no attribute matches, an error
        is logged and None is returned.
    """
    name_parts = "app_name".split("_")
    # All suffixes of the split resource name, longest first.
    suffixes = ["_".join(name_parts[start:]) for start in range(len(name_parts))]
    for field, field_value in vars(self).items():
        if field in suffixes or field == "name":
            return field_value
    logger.error("Name attribute not found for object app")
    return None
[docs]
@classmethod
@Base.add_validate_call
def create(
    cls,
    domain_id: str,
    app_type: str,
    app_name: str,
    user_profile_name: Optional[Union[str, object]] = Unassigned(),
    space_name: Optional[Union[str, object]] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    resource_spec: Optional[ResourceSpec] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["App"]:
    """
    Create a App resource
    After the create call succeeds the app is described again through get(), and
    the user profile / space that owns the app is forwarded to that lookup.
    Parameters:
        domain_id: The domain ID.
        app_type: The type of app.
        app_name: The name of the app.
        user_profile_name: The user profile name. If this value is not set, then SpaceName must be set.
        space_name: The name of the space. If this value is not set, then UserProfileName must be set.
        tags: Each tag consists of a key and an optional value. Tag keys must be unique per resource.
        resource_spec: The instance type and the Amazon Resource Name (ARN) of the SageMaker image created on the instance. The value of InstanceType passed as part of the ResourceSpec in the CreateApp call overrides the value passed as part of the ResourceSpec configured for the user profile or the domain. If InstanceType is not specified in any of those three ResourceSpec values for a KernelGateway app, the CreateApp call fails with a request validation error.
        session: Boto3 session.
        region: Region name.
    Returns:
        The App resource.
    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceInUse: Resource being accessed is in use.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating app resource.")
    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    operation_input_args = {
        "DomainId": domain_id,
        "UserProfileName": user_profile_name,
        "SpaceName": space_name,
        "AppType": app_type,
        "AppName": app_name,
        "Tags": tags,
        "ResourceSpec": resource_spec,
    }
    operation_input_args = Base.populate_chained_attributes(
        resource_name="App", operation_input_args=operation_input_args
    )
    logger.debug(f"Input request: {operation_input_args}")
    # serialize the input request
    operation_input_args = serialize(operation_input_args)
    logger.debug(f"Serialized input request: {operation_input_args}")
    # create the resource
    response = client.create_app(**operation_input_args)
    logger.debug(f"Response: {response}")

    # The describe call behind get() needs the owning user profile or space (see
    # the get() parameter docs), so forward whichever one was sent with the
    # create request. Only plain strings are forwarded: get() validates these
    # arguments as Optional[str], and an unset value must simply be omitted.
    # NOTE(review): assumes the processed request holds chained resource objects
    # as their string names - confirm against populate_chained_attributes/serialize.
    owner_kwargs = {}
    for get_argument, request_key in (
        ("user_profile_name", "UserProfileName"),
        ("space_name", "SpaceName"),
    ):
        owner_value = operation_input_args.get(request_key)
        if isinstance(owner_value, str):
            owner_kwargs[get_argument] = owner_value

    return cls.get(
        domain_id=domain_id,
        app_type=app_type,
        app_name=app_name,
        session=session,
        region=region,
        **owner_kwargs,
    )
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    domain_id: str,
    app_type: str,
    app_name: str,
    user_profile_name: Optional[str] = Unassigned(),
    space_name: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["App"]:
    """
    Describe an existing App and build a resource object from the response.
    Parameters:
        domain_id: The domain ID.
        app_type: The type of app.
        app_name: The name of the app.
        user_profile_name: The user profile name. If this value is not set, then SpaceName must be set.
        space_name: The name of the space.
        session: Boto3 session.
        region: Region name.
    Returns:
        The App resource.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    # Build and serialize the describe request.
    describe_request = serialize(
        {
            "DomainId": domain_id,
            "UserProfileName": user_profile_name,
            "SpaceName": space_name,
            "AppType": app_type,
            "AppName": app_name,
        }
    )
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    describe_response = sagemaker_client.describe_app(**describe_request)
    logger.debug(describe_response)

    # Deserialize the response into constructor arguments.
    return cls(**transform(describe_response, "DescribeAppResponse"))
[docs]
@Base.add_validate_call
def refresh(self) -> Optional["App"]:
    """
    Re-describe this App and update the instance with the response.
    The request is built from the identifying fields currently stored on the
    instance, and the client is obtained from Base.get_sagemaker_client() with
    its default arguments.
    Returns:
        The App resource (this same instance).
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    describe_request = serialize(
        {
            "DomainId": self.domain_id,
            "UserProfileName": self.user_profile_name,
            "SpaceName": self.space_name,
            "AppType": self.app_type,
            "AppName": self.app_name,
        }
    )
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client()
    describe_response = sagemaker_client.describe_app(**describe_request)

    # Passing self lets transform write the deserialized response onto this object.
    transform(describe_response, "DescribeAppResponse", self)
    return self
[docs]
@Base.add_validate_call
def delete(self) -> None:
    """
    Issue the delete call for this App.
    The request is built from the identifying fields stored on the instance. The
    method returns once the delete call has been made; it does not wait for the
    status to change.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceInUse: Resource being accessed is in use.
        ResourceNotFound: Resource being access is not found.
    """
    sagemaker_client = Base.get_sagemaker_client()
    delete_request = serialize(
        {
            "DomainId": self.domain_id,
            "UserProfileName": self.user_profile_name,
            "SpaceName": self.space_name,
            "AppType": self.app_type,
            "AppName": self.app_name,
        }
    )
    logger.debug(f"Serialized input request: {delete_request}")

    sagemaker_client.delete_app(**delete_request)
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def wait_for_status(
    self,
    target_status: Literal["Deleted", "Deleting", "Failed", "InService", "Pending"],
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Poll this App until it reports the requested status.
    Each iteration refreshes the resource, then checks in order: target reached,
    failed status, timeout. The timeout is only evaluated after a refresh, so the
    resource is always polled at least once.
    Parameters:
        target_status: The status to wait for.
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
    Raises:
        TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
        FailedStatusError: If the resource reaches a failed state.
        WaiterError: Raised when an error occurs while waiting.
    """
    started_at = time.time()

    # Live console widgets: a spinner with elapsed time plus a status line.
    spinner = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    spinner.add_task(f"Waiting for App to reach [bold]{target_status} status...")
    status_line = Status("Current status:")
    wait_panel = Panel(
        Group(spinner, status_line),
        title="Wait Log Panel",
        border_style=Style(color=Color.BLUE.value),
    )

    with Live(wait_panel, transient=True):
        while True:
            self.refresh()
            observed_status = self.status
            status_line.update(f"Current status: [bold]{observed_status}")

            # Checked first so that waiting for "Failed" itself returns normally.
            if target_status == observed_status:
                logger.info(f"Final Resource Status: [bold]{observed_status}")
                return

            if "failed" in observed_status.lower():
                raise FailedStatusError(
                    resource_type="App",
                    status=observed_status,
                    reason=self.failure_reason,
                )

            timed_out = timeout is not None and time.time() - started_at >= timeout
            if timed_out:
                # NOTE(review): the keyword is spelled `resouce_type` here; confirm it
                # matches TimeoutExceededError's signature before renaming it.
                raise TimeoutExceededError(resouce_type="App", status=observed_status)

            time.sleep(poll)
[docs]
@Base.add_validate_call
def wait_for_delete(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Poll this App until it is deleted.
    The wait ends when the refreshed status is "deleted" (case-insensitive) or
    when the refresh fails with a ClientError whose code contains
    "ResourceNotFound" or "ValidationException". Any other ClientError is re-raised.
    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
        DeleteFailedStatusError: If the resource reaches a failed state.
        WaiterError: Raised when an error occurs while waiting.
    """
    started_at = time.time()

    # Live console widgets: a spinner with elapsed time plus a status line.
    spinner = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    spinner.add_task("Waiting for App to be deleted...")
    status_line = Status("Current status:")
    wait_panel = Panel(
        Group(spinner, status_line),
        title="Wait Log Panel",
        border_style=Style(color=Color.BLUE.value),
    )

    not_found_markers = ("ResourceNotFound", "ValidationException")
    with Live(wait_panel):
        while True:
            try:
                self.refresh()
                observed_status = self.status
                status_line.update(f"Current status: [bold]{observed_status}")

                if observed_status.lower() == "deleted":
                    print("Resource was deleted.")
                    return

                if timeout is not None and time.time() - started_at >= timeout:
                    # NOTE(review): the keyword is spelled `resouce_type` here; confirm
                    # it matches TimeoutExceededError's signature before renaming it.
                    raise TimeoutExceededError(resouce_type="App", status=observed_status)
            except botocore.exceptions.ClientError as e:
                error_code = e.response["Error"]["Code"]
                # A failing describe call with one of these codes is treated as
                # "the resource is gone".
                if any(marker in error_code for marker in not_found_markers):
                    logger.info("Resource was not found. It may have been deleted.")
                    return
                raise e
            time.sleep(poll)
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    sort_order: Optional[str] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    domain_id_equals: Optional[str] = Unassigned(),
    user_profile_name_equals: Optional[str] = Unassigned(),
    space_name_equals: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["App"]:
    """
    Build an iterator over App resources using the list_apps operation.
    Parameters:
        sort_order: The sort order for the results. The default is Ascending.
        sort_by: The parameter by which to sort the results. The default is CreationTime.
        domain_id_equals: A parameter to search for the domain ID.
        user_profile_name_equals: A parameter to search by user profile name. If SpaceNameEquals is set, then this value cannot be set.
        space_name_equals: A parameter to search by space name. If UserProfileNameEquals is set, then this value cannot be set.
        session: Boto3 session.
        region: Region name.
    Returns:
        Iterator for listed App resources.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
    """
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Serialize the filter arguments for the list call.
    list_request = serialize(
        {
            "SortOrder": sort_order,
            "SortBy": sort_by,
            "DomainIdEquals": domain_id_equals,
            "UserProfileNameEquals": user_profile_name_equals,
            "SpaceNameEquals": space_name_equals,
        }
    )
    logger.debug(f"Serialized input request: {list_request}")

    return ResourceIterator(
        client=sagemaker_client,
        list_method="list_apps",
        summaries_key="Apps",
        summary_name="AppDetails",
        resource_cls=App,
        list_method_kwargs=list_request,
    )
[docs]
class AppImageConfig(Base):
"""
Class representing resource AppImageConfig
app_image_config_name is the only field declared without a default and is
therefore required; every other field defaults to Unassigned().
Attributes:
app_image_config_arn: The ARN of the AppImageConfig.
app_image_config_name: The name of the AppImageConfig.
creation_time: When the AppImageConfig was created.
last_modified_time: When the AppImageConfig was last modified.
kernel_gateway_image_config: The configuration of a KernelGateway app.
jupyter_lab_app_image_config: The configuration of the JupyterLab app.
code_editor_app_image_config: The configuration of the Code Editor app.
"""
# Required identifier of the AppImageConfig.
app_image_config_name: str
# Optional fields; Unassigned() marks a value that has not been provided.
app_image_config_arn: Optional[str] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
last_modified_time: Optional[datetime.datetime] = Unassigned()
kernel_gateway_image_config: Optional[KernelGatewayImageConfig] = Unassigned()
jupyter_lab_app_image_config: Optional[JupyterLabAppImageConfig] = Unassigned()
code_editor_app_image_config: Optional[CodeEditorAppImageConfig] = Unassigned()
def get_name(self) -> str:
    """
    Look up the value of this resource's name attribute.
    The instance attributes are scanned in order; the first one that is called
    "name" or that equals a trailing underscore-delimited suffix of
    "app_image_config_name" ("app_image_config_name", "image_config_name",
    "config_name", "name") wins.
    Returns:
        The value of the matching attribute. When no attribute matches, an error
        is logged and None is returned.
    """
    name_parts = "app_image_config_name".split("_")
    # All suffixes of the split resource name, longest first.
    suffixes = ["_".join(name_parts[start:]) for start in range(len(name_parts))]
    for field, field_value in vars(self).items():
        if field in suffixes or field == "name":
            return field_value
    logger.error("Name attribute not found for object app_image_config")
    return None
[docs]
@classmethod
@Base.add_validate_call
def create(
    cls,
    app_image_config_name: str,
    tags: Optional[List[Tag]] = Unassigned(),
    kernel_gateway_image_config: Optional[KernelGatewayImageConfig] = Unassigned(),
    jupyter_lab_app_image_config: Optional[JupyterLabAppImageConfig] = Unassigned(),
    code_editor_app_image_config: Optional[CodeEditorAppImageConfig] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["AppImageConfig"]:
    """
    Create an AppImageConfig and return it as described by a follow-up get().
    Parameters:
        app_image_config_name: The name of the AppImageConfig. Must be unique to your account.
        tags: A list of tags to apply to the AppImageConfig.
        kernel_gateway_image_config: The KernelGatewayImageConfig. You can only specify one image kernel in the AppImageConfig API. This kernel will be shown to users before the image starts. Once the image runs, all kernels are visible in JupyterLab.
        jupyter_lab_app_image_config: The JupyterLabAppImageConfig. You can only specify one image kernel in the AppImageConfig API. This kernel is shown to users before the image starts. After the image runs, all kernels are visible in JupyterLab.
        code_editor_app_image_config: The CodeEditorAppImageConfig. You can only specify one image kernel in the AppImageConfig API. This kernel is shown to users before the image starts. After the image runs, all kernels are visible in Code Editor.
        session: Boto3 session.
        region: Region name.
    Returns:
        The AppImageConfig resource.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceInUse: Resource being accessed is in use.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating app_image_config resource.")
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    create_request = Base.populate_chained_attributes(
        resource_name="AppImageConfig",
        operation_input_args={
            "AppImageConfigName": app_image_config_name,
            "Tags": tags,
            "KernelGatewayImageConfig": kernel_gateway_image_config,
            "JupyterLabAppImageConfig": jupyter_lab_app_image_config,
            "CodeEditorAppImageConfig": code_editor_app_image_config,
        },
    )
    logger.debug(f"Input request: {create_request}")

    # Serialize the request before handing it to the client.
    create_request = serialize(create_request)
    logger.debug(f"Serialized input request: {create_request}")

    create_response = sagemaker_client.create_app_image_config(**create_request)
    logger.debug(f"Response: {create_response}")

    # Describe the freshly created resource to return a populated object.
    return cls.get(app_image_config_name=app_image_config_name, session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    app_image_config_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["AppImageConfig"]:
    """
    Describe an existing AppImageConfig and build a resource object from the response.
    Parameters:
        app_image_config_name: The name of the AppImageConfig to describe.
        session: Boto3 session.
        region: Region name.
    Returns:
        The AppImageConfig resource.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    # Build and serialize the describe request.
    describe_request = serialize({"AppImageConfigName": app_image_config_name})
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    describe_response = sagemaker_client.describe_app_image_config(**describe_request)
    logger.debug(describe_response)

    # Deserialize the response into constructor arguments.
    return cls(**transform(describe_response, "DescribeAppImageConfigResponse"))
[docs]
@Base.add_validate_call
def refresh(self) -> Optional["AppImageConfig"]:
    """
    Re-describe this AppImageConfig and update the instance with the response.
    The request uses the name stored on the instance, and the client is obtained
    from Base.get_sagemaker_client() with its default arguments.
    Returns:
        The AppImageConfig resource (this same instance).
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    describe_request = serialize({"AppImageConfigName": self.app_image_config_name})
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client()
    describe_response = sagemaker_client.describe_app_image_config(**describe_request)

    # Passing self lets transform write the deserialized response onto this object.
    transform(describe_response, "DescribeAppImageConfigResponse", self)
    return self
[docs]
@Base.add_validate_call
def update(
    self,
    kernel_gateway_image_config: Optional[KernelGatewayImageConfig] = Unassigned(),
    jupyter_lab_app_image_config: Optional[JupyterLabAppImageConfig] = Unassigned(),
    code_editor_app_image_config: Optional[CodeEditorAppImageConfig] = Unassigned(),
) -> Optional["AppImageConfig"]:
    """
    Update this AppImageConfig, then refresh the instance.
    Parameters:
        kernel_gateway_image_config: Sent as KernelGatewayImageConfig in the update request.
        jupyter_lab_app_image_config: Sent as JupyterLabAppImageConfig in the update request.
        code_editor_app_image_config: Sent as CodeEditorAppImageConfig in the update request.
    Returns:
        The AppImageConfig resource (this same instance, refreshed after the update).
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    logger.info("Updating app_image_config resource.")
    sagemaker_client = Base.get_sagemaker_client()

    update_request = {
        "AppImageConfigName": self.app_image_config_name,
        "KernelGatewayImageConfig": kernel_gateway_image_config,
        "JupyterLabAppImageConfig": jupyter_lab_app_image_config,
        "CodeEditorAppImageConfig": code_editor_app_image_config,
    }
    logger.debug(f"Input request: {update_request}")

    # Serialize the request before handing it to the client.
    update_request = serialize(update_request)
    logger.debug(f"Serialized input request: {update_request}")

    # Update the resource, then pull the new state onto this instance.
    update_response = sagemaker_client.update_app_image_config(**update_request)
    logger.debug(f"Response: {update_response}")
    self.refresh()
    return self
[docs]
@Base.add_validate_call
def delete(self) -> None:
    """
    Issue the delete call for this AppImageConfig.
    The request uses the name stored on the instance.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    sagemaker_client = Base.get_sagemaker_client()
    delete_request = serialize({"AppImageConfigName": self.app_image_config_name})
    logger.debug(f"Serialized input request: {delete_request}")

    sagemaker_client.delete_app_image_config(**delete_request)
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    name_contains: Optional[str] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    modified_time_before: Optional[datetime.datetime] = Unassigned(),
    modified_time_after: Optional[datetime.datetime] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["AppImageConfig"]:
    """
    Build an iterator over AppImageConfig resources using the list_app_image_configs operation.
    Parameters:
        name_contains: A filter that returns only AppImageConfigs whose name contains the specified string.
        creation_time_before: A filter that returns only AppImageConfigs created on or before the specified time.
        creation_time_after: A filter that returns only AppImageConfigs created on or after the specified time.
        modified_time_before: A filter that returns only AppImageConfigs modified on or before the specified time.
        modified_time_after: A filter that returns only AppImageConfigs modified on or after the specified time.
        sort_by: The property used to sort results. The default value is CreationTime.
        sort_order: The sort order. The default value is Descending.
        session: Boto3 session.
        region: Region name.
    Returns:
        Iterator for listed AppImageConfig resources.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
    """
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Serialize the filter arguments for the list call.
    list_request = serialize(
        {
            "NameContains": name_contains,
            "CreationTimeBefore": creation_time_before,
            "CreationTimeAfter": creation_time_after,
            "ModifiedTimeBefore": modified_time_before,
            "ModifiedTimeAfter": modified_time_after,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
    )
    logger.debug(f"Serialized input request: {list_request}")

    return ResourceIterator(
        client=sagemaker_client,
        list_method="list_app_image_configs",
        summaries_key="AppImageConfigs",
        summary_name="AppImageConfigDetails",
        resource_cls=AppImageConfig,
        list_method_kwargs=list_request,
    )
[docs]
class Artifact(Base):
"""
Class representing resource Artifact
artifact_arn is the only field declared without a default and is therefore
required; every other field defaults to Unassigned().
Attributes:
artifact_name: The name of the artifact.
artifact_arn: The Amazon Resource Name (ARN) of the artifact.
source: The source of the artifact.
artifact_type: The type of the artifact.
properties: A list of the artifact's properties.
creation_time: When the artifact was created.
created_by:
last_modified_time: When the artifact was last modified.
last_modified_by:
metadata_properties:
lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group.
"""
# Required identifier of the artifact.
artifact_arn: str
# Optional fields; Unassigned() marks a value that has not been provided.
artifact_name: Optional[str] = Unassigned()
source: Optional[ArtifactSource] = Unassigned()
artifact_type: Optional[str] = Unassigned()
properties: Optional[Dict[str, str]] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
created_by: Optional[UserContext] = Unassigned()
last_modified_time: Optional[datetime.datetime] = Unassigned()
last_modified_by: Optional[UserContext] = Unassigned()
metadata_properties: Optional[MetadataProperties] = Unassigned()
lineage_group_arn: Optional[str] = Unassigned()
def get_name(self) -> str:
    """
    Look up the value of this resource's name attribute.
    The instance attributes are scanned in order; the first one that is called
    "name" or that equals a trailing underscore-delimited suffix of
    "artifact_name" ("artifact_name", "name") wins.
    Returns:
        The value of the matching attribute. When no attribute matches, an error
        is logged and None is returned.
    """
    name_parts = "artifact_name".split("_")
    # All suffixes of the split resource name, longest first.
    suffixes = ["_".join(name_parts[start:]) for start in range(len(name_parts))]
    for field, field_value in vars(self).items():
        if field in suffixes or field == "name":
            return field_value
    logger.error("Name attribute not found for object artifact")
    return None
[docs]
@classmethod
@Base.add_validate_call
def create(
    cls,
    source: ArtifactSource,
    artifact_type: str,
    artifact_name: Optional[str] = Unassigned(),
    properties: Optional[Dict[str, str]] = Unassigned(),
    metadata_properties: Optional[MetadataProperties] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Artifact"]:
    """
    Create an Artifact and return it as described by a follow-up get().
    The ARN used for the follow-up lookup is read from the "ArtifactArn" key of
    the create response.
    Parameters:
        source: The ID, ID type, and URI of the source.
        artifact_type: The artifact type.
        artifact_name: The name of the artifact. Must be unique to your account in an Amazon Web Services Region.
        properties: A list of properties to add to the artifact.
        metadata_properties: Sent as MetadataProperties in the create request.
        tags: A list of tags to apply to the artifact.
        session: Boto3 session.
        region: Region name.
    Returns:
        The Artifact resource.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating artifact resource.")
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    create_request = Base.populate_chained_attributes(
        resource_name="Artifact",
        operation_input_args={
            "ArtifactName": artifact_name,
            "Source": source,
            "ArtifactType": artifact_type,
            "Properties": properties,
            "MetadataProperties": metadata_properties,
            "Tags": tags,
        },
    )
    logger.debug(f"Input request: {create_request}")

    # Serialize the request before handing it to the client.
    create_request = serialize(create_request)
    logger.debug(f"Serialized input request: {create_request}")

    create_response = sagemaker_client.create_artifact(**create_request)
    logger.debug(f"Response: {create_response}")

    # Describe the freshly created resource to return a populated object.
    return cls.get(artifact_arn=create_response["ArtifactArn"], session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    artifact_arn: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Artifact"]:
    """
    Describe an existing Artifact and build a resource object from the response.
    Parameters:
        artifact_arn: The Amazon Resource Name (ARN) of the artifact to describe.
        session: Boto3 session.
        region: Region name.
    Returns:
        The Artifact resource.
    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            e.response['Error']['Message'] and e.response['Error']['Code'].
        ResourceNotFound: Resource being access is not found.
    """
    # Build and serialize the describe request.
    describe_request = serialize({"ArtifactArn": artifact_arn})
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    describe_response = sagemaker_client.describe_artifact(**describe_request)
    logger.debug(describe_response)

    # Deserialize the response into constructor arguments.
    return cls(**transform(describe_response, "DescribeArtifactResponse"))
[docs]
@Base.add_validate_call
def refresh(self) -> Optional["Artifact"]:
    """
    Re-describe this Artifact and update the instance in place.

    The artifact is identified by ``self.artifact_arn``; the client comes from
    ``Base.get_sagemaker_client()`` called without arguments.

    Returns:
        The Artifact resource (this same instance).

    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: Resource being accessed is not found.
    """
    request = serialize({"ArtifactArn": self.artifact_arn})
    logger.debug(f"Serialized input request: {request}")

    described = Base.get_sagemaker_client().describe_artifact(**request)

    # Passing ``self`` as the third argument makes transform() update this
    # object instead of returning a fresh structure.
    transform(described, "DescribeArtifactResponse", self)
    return self
[docs]
@Base.add_validate_call
def update(
    self,
    artifact_name: Optional[str] = Unassigned(),
    properties: Optional[Dict[str, str]] = Unassigned(),
    properties_to_remove: Optional[List[str]] = Unassigned(),
) -> Optional["Artifact"]:
    """
    Update this Artifact resource via ``update_artifact`` and then refresh it.

    Parameters:
        artifact_name: Value sent as ``ArtifactName`` in the update request.
        properties: Value sent as ``Properties`` in the update request.
        properties_to_remove: A list of properties to remove.

    Returns:
        The Artifact resource (this same instance, after ``refresh()``).

    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact.
        ResourceNotFound: Resource being accessed is not found.
    """
    logger.info("Updating artifact resource.")
    sagemaker = Base.get_sagemaker_client()

    raw_request = {
        "ArtifactArn": self.artifact_arn,
        "ArtifactName": artifact_name,
        "Properties": properties,
        "PropertiesToRemove": properties_to_remove,
    }
    logger.debug(f"Input request: {raw_request}")

    # serialize the input request
    request = serialize(raw_request)
    logger.debug(f"Serialized input request: {request}")

    # update the resource
    update_response = sagemaker.update_artifact(**request)
    logger.debug(f"Response: {update_response}")

    # Re-describe so the instance reflects what the update call produced.
    self.refresh()
    return self
[docs]
@Base.add_validate_call
def delete(self) -> None:
    """
    Delete this Artifact resource via ``delete_artifact``.

    The request carries both ``self.artifact_arn`` and ``self.source``.

    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: Resource being accessed is not found.
    """
    sagemaker = Base.get_sagemaker_client()

    # serialize the input request
    request = serialize(
        {
            "ArtifactArn": self.artifact_arn,
            "Source": self.source,
        }
    )
    logger.debug(f"Serialized input request: {request}")

    sagemaker.delete_artifact(**request)

    # The message is logged only once the delete call has returned.
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    source_uri: Optional[str] = Unassigned(),
    artifact_type: Optional[str] = Unassigned(),
    created_after: Optional[datetime.datetime] = Unassigned(),
    created_before: Optional[datetime.datetime] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["Artifact"]:
    """
    Build an iterator over Artifact resources backed by ``list_artifacts``.

    Parameters:
        source_uri: A filter that returns only artifacts with the specified source URI.
        artifact_type: A filter that returns only artifacts of the specified type.
        created_after: A filter that returns only artifacts created on or after the specified time.
        created_before: A filter that returns only artifacts created on or before the specified time.
        sort_by: The property used to sort results. The default value is CreationTime.
        sort_order: The sort order. The default value is Descending.
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed Artifact resources.

    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: Resource being accessed is not found.
    """
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Each filter is keyed by the name the list call expects.
    filters = {
        "SourceUri": source_uri,
        "ArtifactType": artifact_type,
        "CreatedAfter": created_after,
        "CreatedBefore": created_before,
        "SortBy": sort_by,
        "SortOrder": sort_order,
    }

    # serialize the input request
    list_kwargs = serialize(filters)
    logger.debug(f"Serialized input request: {list_kwargs}")

    # No list call is made in this method; the client and kwargs are handed
    # to the iterator.
    return ResourceIterator(
        client=sagemaker,
        list_method="list_artifacts",
        summaries_key="ArtifactSummaries",
        summary_name="ArtifactSummary",
        resource_cls=Artifact,
        list_method_kwargs=list_kwargs,
    )
[docs]
class Association(Base):
    """
    Class representing resource Association

    Attributes:
        source_arn: The ARN of the source.
        destination_arn: The Amazon Resource Name (ARN) of the destination.
        source_type: The source type.
        destination_type: The destination type.
        association_type: The type of the association.
        source_name: The name of the source.
        destination_name: The name of the destination.
        creation_time: When the association was created.
        created_by:
    """

    source_arn: Optional[str] = Unassigned()
    destination_arn: Optional[str] = Unassigned()
    source_type: Optional[str] = Unassigned()
    destination_type: Optional[str] = Unassigned()
    association_type: Optional[str] = Unassigned()
    source_name: Optional[str] = Unassigned()
    destination_name: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    created_by: Optional[UserContext] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of the first instance attribute that looks like a name.

        The candidates are every underscore-joined suffix of
        ``"association_name"`` (``"association_name"`` and ``"name"``).

        Returns:
            The matching attribute's value, or ``None`` (after logging an
            error) when no attribute matches. None of the fields declared on
            this class match the candidates, so a miss is the expected outcome
            here unless ``Base`` contributes such an attribute.
        """
        name_parts = "association_name".split("_")
        candidates = {"_".join(name_parts[i:]) for i in range(len(name_parts))}
        for attribute, value in vars(self).items():
            if attribute == "name" or attribute in candidates:
                return value
        logger.error("Name attribute not found for object association")
        return None

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete a Association resource

        The association is identified by ``self.source_arn`` and
        ``self.destination_arn``.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """
        client = Base.get_sagemaker_client()

        operation_input_args = {
            "SourceArn": self.source_arn,
            "DestinationArn": self.destination_arn,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client.delete_association(**operation_input_args)

        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        source_arn: Optional[str] = Unassigned(),
        destination_arn: Optional[str] = Unassigned(),
        source_type: Optional[str] = Unassigned(),
        destination_type: Optional[str] = Unassigned(),
        association_type: Optional[str] = Unassigned(),
        created_after: Optional[datetime.datetime] = Unassigned(),
        created_before: Optional[datetime.datetime] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["Association"]:
        """
        Get all Association resources

        Parameters:
            source_arn: A filter that returns only associations with the specified source ARN.
            destination_arn: A filter that returns only associations with the specified destination Amazon Resource Name (ARN).
            source_type: A filter that returns only associations with the specified source type.
            destination_type: A filter that returns only associations with the specified destination type.
            association_type: A filter that returns only associations of the specified type.
            created_after: A filter that returns only associations created on or after the specified time.
            created_before: A filter that returns only associations created on or before the specified time.
            sort_by: The property used to sort results. The default value is CreationTime.
            sort_order: The sort order. The default value is Descending.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed Association resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "SourceArn": source_arn,
            "DestinationArn": destination_arn,
            "SourceType": source_type,
            "DestinationType": destination_type,
            "AssociationType": association_type,
            "CreatedAfter": created_after,
            "CreatedBefore": created_before,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_associations",
            summaries_key="AssociationSummaries",
            summary_name="AssociationSummary",
            resource_cls=Association,
            list_method_kwargs=operation_input_args,
        )

    @classmethod
    @Base.add_validate_call
    def add(
        cls,
        source_arn: str,
        destination_arn: str,
        association_type: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> None:
        """
        Creates an association between the source and the destination.

        Parameters:
            source_arn: The ARN of the source.
            destination_arn: The Amazon Resource Name (ARN) of the destination.
            association_type: The type of association. The following are suggested uses for
                each type. Amazon SageMaker places no restrictions on their use.
                ContributedTo - The source contributed to the destination or had a part in
                enabling the destination. For example, the training data contributed to the
                training job.
                AssociatedWith - The source is connected to the destination. For example, an
                approval workflow is associated with a model deployment.
                DerivedFrom - The destination is a modification of the source. For example, a
                digest output of a channel input for a processing job is derived from the
                original inputs.
                Produced - The source generated the destination. For example, a training job
                produced a model artifact.
            session: Boto3 session.
            region: Region name.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
            ResourceNotFound: Resource being accessed is not found.
        """
        operation_input_args = {
            "SourceArn": source_arn,
            "DestinationArn": destination_arn,
            "AssociationType": association_type,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        logger.debug("Calling add_association API")
        response = client.add_association(**operation_input_args)
        logger.debug(f"Response: {response}")
[docs]
class AutoMLJob(Base):
    """
    Class representing resource AutoMLJob

    Attributes:
        auto_ml_job_name: Returns the name of the AutoML job.
        auto_ml_job_arn: Returns the ARN of the AutoML job.
        input_data_config: Returns the input data configuration for the AutoML job.
        output_data_config: Returns the job's output data config.
        role_arn: The ARN of the IAM role that has read permission to the input data location and write permission to the output data location in Amazon S3.
        creation_time: Returns the creation time of the AutoML job.
        last_modified_time: Returns the job's last modified time.
        auto_ml_job_status: Returns the status of the AutoML job.
        auto_ml_job_secondary_status: Returns the secondary status of the AutoML job.
        auto_ml_job_objective: Returns the job's objective.
        problem_type: Returns the job's problem type.
        auto_ml_job_config: Returns the configuration for the AutoML job.
        end_time: Returns the end time of the AutoML job.
        failure_reason: Returns the failure reason for an AutoML job, when applicable.
        partial_failure_reasons: Returns a list of reasons for partial failures within an AutoML job.
        best_candidate: The best model candidate selected by SageMaker Autopilot using both the best objective metric and lowest InferenceLatency for an experiment.
        generate_candidate_definitions_only: Indicates whether the output for an AutoML job generates candidate definitions only.
        auto_ml_job_artifacts: Returns information on the job's artifacts found in AutoMLJobArtifacts.
        resolved_attributes: Contains ProblemType, AutoMLJobObjective, and CompletionCriteria. If you do not provide these values, they are inferred.
        model_deploy_config: Indicates whether the model was deployed automatically to an endpoint and the name of that endpoint if deployed automatically.
        model_deploy_result: Provides information about endpoint for the model deployment.
    """

    auto_ml_job_name: str
    auto_ml_job_arn: Optional[str] = Unassigned()
    input_data_config: Optional[List[AutoMLChannel]] = Unassigned()
    output_data_config: Optional[AutoMLOutputDataConfig] = Unassigned()
    role_arn: Optional[str] = Unassigned()
    auto_ml_job_objective: Optional[AutoMLJobObjective] = Unassigned()
    problem_type: Optional[str] = Unassigned()
    auto_ml_job_config: Optional[AutoMLJobConfig] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    end_time: Optional[datetime.datetime] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    failure_reason: Optional[str] = Unassigned()
    partial_failure_reasons: Optional[List[AutoMLPartialFailureReason]] = Unassigned()
    best_candidate: Optional[AutoMLCandidate] = Unassigned()
    auto_ml_job_status: Optional[str] = Unassigned()
    auto_ml_job_secondary_status: Optional[str] = Unassigned()
    generate_candidate_definitions_only: Optional[bool] = Unassigned()
    auto_ml_job_artifacts: Optional[AutoMLJobArtifacts] = Unassigned()
    resolved_attributes: Optional[ResolvedAttributes] = Unassigned()
    model_deploy_config: Optional[ModelDeployConfig] = Unassigned()
    model_deploy_result: Optional[ModelDeployResult] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of the first instance attribute that looks like a name.

        The candidates are every underscore-joined suffix of
        ``"auto_ml_job_name"``: ``"auto_ml_job_name"``, ``"ml_job_name"``,
        ``"job_name"`` and ``"name"``.

        Returns:
            The matching attribute's value, or ``None`` (after logging an
            error) when no attribute matches.
        """
        name_parts = "auto_ml_job_name".split("_")
        candidates = {"_".join(name_parts[i:]) for i in range(len(name_parts))}
        for attribute, value in vars(self).items():
            if attribute == "name" or attribute in candidates:
                return value
        logger.error("Name attribute not found for object auto_ml_job")
        return None

    def populate_inputs_decorator(create_func):
        # Used only as a decorator for ``create`` below. The wrapper passes the
        # caller's keyword arguments, together with the schema, through
        # Base.get_updated_kwargs_with_configured_attributes before invoking
        # the wrapped function.
        @functools.wraps(create_func)
        def wrapper(*args, **kwargs):
            config_schema_for_resource = {
                "output_data_config": {
                    "s3_output_path": {"type": "string"},
                    "kms_key_id": {"type": "string"},
                },
                "role_arn": {"type": "string"},
                "auto_ml_job_config": {
                    "security_config": {
                        "volume_kms_key_id": {"type": "string"},
                        "vpc_config": {
                            "security_group_ids": {"type": "array", "items": {"type": "string"}},
                            "subnets": {"type": "array", "items": {"type": "string"}},
                        },
                    },
                    "candidate_generation_config": {
                        "feature_specification_s3_uri": {"type": "string"}
                    },
                },
            }
            return create_func(
                *args,
                **Base.get_updated_kwargs_with_configured_attributes(
                    config_schema_for_resource, "AutoMLJob", **kwargs
                ),
            )

        return wrapper

    @classmethod
    @populate_inputs_decorator
    @Base.add_validate_call
    def create(
        cls,
        auto_ml_job_name: str,
        input_data_config: List[AutoMLChannel],
        output_data_config: AutoMLOutputDataConfig,
        role_arn: str,
        problem_type: Optional[str] = Unassigned(),
        auto_ml_job_objective: Optional[AutoMLJobObjective] = Unassigned(),
        auto_ml_job_config: Optional[AutoMLJobConfig] = Unassigned(),
        generate_candidate_definitions_only: Optional[bool] = Unassigned(),
        tags: Optional[List[Tag]] = Unassigned(),
        model_deploy_config: Optional[ModelDeployConfig] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["AutoMLJob"]:
        """
        Create a AutoMLJob resource

        Parameters:
            auto_ml_job_name: Identifies an Autopilot job. The name must be unique to your account and is case insensitive.
            input_data_config: An array of channel objects that describes the input data and its location. Each channel is a named input source. Similar to InputDataConfig supported by HyperParameterTrainingJobDefinition. Format(s) supported: CSV, Parquet. A minimum of 500 rows is required for the training dataset. There is not a minimum number of rows required for the validation dataset.
            output_data_config: Provides information about encryption and the Amazon S3 output path needed to store artifacts from an AutoML job. Format(s) supported: CSV.
            role_arn: The ARN of the role that is used to access the data.
            problem_type: Defines the type of supervised learning problem available for the candidates. For more information, see SageMaker Autopilot problem types.
            auto_ml_job_objective: Specifies a metric to minimize or maximize as the objective of a job. If not specified, the default objective metric depends on the problem type. See AutoMLJobObjective for the default values.
            auto_ml_job_config: A collection of settings used to configure an AutoML job.
            generate_candidate_definitions_only: Generates possible candidates without training the models. A candidate is a combination of data preprocessors, algorithms, and algorithm parameter settings.
            tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web ServicesResources. Tag keys must be unique per resource.
            model_deploy_config: Specifies how to generate the endpoint name for an automatic one-click Autopilot model deployment.
            session: Boto3 session.
            region: Region name.

        Returns:
            The AutoMLJob resource, fetched with ``get`` after the create call returns.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceInUse: Resource being accessed is in use.
            ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """
        logger.info("Creating auto_ml_job resource.")
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "AutoMLJobName": auto_ml_job_name,
            "InputDataConfig": input_data_config,
            "OutputDataConfig": output_data_config,
            "ProblemType": problem_type,
            "AutoMLJobObjective": auto_ml_job_objective,
            "AutoMLJobConfig": auto_ml_job_config,
            "RoleArn": role_arn,
            "GenerateCandidateDefinitionsOnly": generate_candidate_definitions_only,
            "Tags": tags,
            "ModelDeployConfig": model_deploy_config,
        }

        operation_input_args = Base.populate_chained_attributes(
            resource_name="AutoMLJob", operation_input_args=operation_input_args
        )

        logger.debug(f"Input request: {operation_input_args}")
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # create the resource
        response = client.create_auto_ml_job(**operation_input_args)
        logger.debug(f"Response: {response}")

        return cls.get(auto_ml_job_name=auto_ml_job_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        auto_ml_job_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["AutoMLJob"]:
        """
        Get a AutoMLJob resource

        Parameters:
            auto_ml_job_name: Requests information about an AutoML job using its unique name.
            session: Boto3 session.
            region: Region name.

        Returns:
            The AutoMLJob resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """
        operation_input_args = {
            "AutoMLJobName": auto_ml_job_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = client.describe_auto_ml_job(**operation_input_args)

        logger.debug(response)

        # deserialize the response
        transformed_response = transform(response, "DescribeAutoMLJobResponse")
        auto_ml_job = cls(**transformed_response)
        return auto_ml_job

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["AutoMLJob"]:
        """
        Refresh a AutoMLJob resource

        Returns:
            The AutoMLJob resource (this same instance).

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """
        operation_input_args = {
            "AutoMLJobName": self.auto_ml_job_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client()
        response = client.describe_auto_ml_job(**operation_input_args)

        # deserialize response and update self
        transform(response, "DescribeAutoMLJobResponse", self)
        return self

    @Base.add_validate_call
    def stop(self) -> None:
        """
        Stop a AutoMLJob resource

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """
        # Obtain the client the same way every other method of this class
        # does, rather than reaching into SageMakerClient directly.
        client = Base.get_sagemaker_client()

        operation_input_args = {
            "AutoMLJobName": self.auto_ml_job_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client.stop_auto_ml_job(**operation_input_args)

        logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")

    @Base.add_validate_call
    def wait(
        self,
        poll: int = 5,
        timeout: Optional[int] = None,
    ) -> None:
        """
        Wait for a AutoMLJob resource.

        Refreshes the resource every ``poll`` seconds until
        ``auto_ml_job_status`` is one of ``Completed``, ``Failed`` or
        ``Stopped``, showing progress in a transient rich panel.

        Parameters:
            poll: The number of seconds to wait between each poll.
            timeout: The maximum number of seconds to wait before timing out.

        Raises:
            TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
            FailedStatusError: If the resource reaches a failed state.
            WaiterError: Raised when an error occurs while waiting.
        """
        terminal_states = ["Completed", "Failed", "Stopped"]
        start_time = time.time()

        progress = Progress(
            SpinnerColumn("bouncingBar"),
            TextColumn("{task.description}"),
            TimeElapsedColumn(),
        )
        progress.add_task("Waiting for AutoMLJob...")
        status = Status("Current status:")

        with Live(
            Panel(
                Group(progress, status),
                title="Wait Log Panel",
                border_style=Style(color=Color.BLUE.value),
            ),
            transient=True,
        ):
            while True:
                self.refresh()
                current_status = self.auto_ml_job_status
                status.update(f"Current status: [bold]{current_status}")

                if current_status in terminal_states:
                    logger.info(f"Final Resource Status: [bold]{current_status}")

                    if "failed" in current_status.lower():
                        raise FailedStatusError(
                            resource_type="AutoMLJob",
                            status=current_status,
                            reason=self.failure_reason,
                        )

                    return

                # The timeout is evaluated after each refresh, so at least one
                # status check always happens.
                if timeout is not None and time.time() - start_time >= timeout:
                    # Keyword spelled ``resource_type`` to match the
                    # FailedStatusError call above.
                    raise TimeoutExceededError(resource_type="AutoMLJob", status=current_status)
                time.sleep(poll)

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        creation_time_after: Optional[datetime.datetime] = Unassigned(),
        creation_time_before: Optional[datetime.datetime] = Unassigned(),
        last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
        last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
        name_contains: Optional[str] = Unassigned(),
        status_equals: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["AutoMLJob"]:
        """
        Get all AutoMLJob resources

        Parameters:
            creation_time_after: Request a list of jobs, using a filter for time.
            creation_time_before: Request a list of jobs, using a filter for time.
            last_modified_time_after: Request a list of jobs, using a filter for time.
            last_modified_time_before: Request a list of jobs, using a filter for time.
            name_contains: Request a list of jobs, using a search filter for name.
            status_equals: Request a list of jobs, using a filter for status.
            sort_order: The sort order for the results. The default is Descending.
            sort_by: The parameter by which to sort the results. The default is Name.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed AutoMLJob resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "CreationTimeAfter": creation_time_after,
            "CreationTimeBefore": creation_time_before,
            "LastModifiedTimeAfter": last_modified_time_after,
            "LastModifiedTimeBefore": last_modified_time_before,
            "NameContains": name_contains,
            "StatusEquals": status_equals,
            "SortOrder": sort_order,
            "SortBy": sort_by,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_auto_ml_jobs",
            summaries_key="AutoMLJobSummaries",
            summary_name="AutoMLJobSummary",
            resource_cls=AutoMLJob,
            list_method_kwargs=operation_input_args,
        )

    @Base.add_validate_call
    def get_all_candidates(
        self,
        status_equals: Optional[str] = Unassigned(),
        candidate_name_equals: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator[AutoMLCandidate]:
        """
        List the candidates created for the job.

        The job is identified by ``self.auto_ml_job_name``.

        Parameters:
            status_equals: List the candidates for the job and filter by status.
            candidate_name_equals: List the candidates for the job and filter by candidate name.
            sort_order: The sort order for the results. The default is Ascending.
            sort_by: The parameter by which to sort the results. The default is Descending.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed AutoMLCandidate.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being accessed is not found.
        """
        operation_input_args = {
            "AutoMLJobName": self.auto_ml_job_name,
            "StatusEquals": status_equals,
            "CandidateNameEquals": candidate_name_equals,
            "SortOrder": sort_order,
            "SortBy": sort_by,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        return ResourceIterator(
            client=client,
            list_method="list_candidates_for_auto_ml_job",
            summaries_key="Candidates",
            summary_name="AutoMLCandidate",
            resource_cls=AutoMLCandidate,
            list_method_kwargs=operation_input_args,
        )
[docs]
class AutoMLJobV2(Base):
"""
Class representing resource AutoMLJobV2
Instances are built from the DescribeAutoMLJobV2 response (see get() and refresh()).
Only auto_ml_job_name is required; every other attribute defaults to Unassigned().
Attributes:
auto_ml_job_name: Returns the name of the AutoML job V2.
auto_ml_job_arn: Returns the Amazon Resource Name (ARN) of the AutoML job V2.
auto_ml_job_input_data_config: Returns an array of channel objects describing the input data and their location.
output_data_config: Returns the job's output data config.
role_arn: The ARN of the IAM role that has read permission to the input data location and write permission to the output data location in Amazon S3.
creation_time: Returns the creation time of the AutoML job V2.
last_modified_time: Returns the job's last modified time.
auto_ml_job_status: Returns the status of the AutoML job V2.
auto_ml_job_secondary_status: Returns the secondary status of the AutoML job V2.
auto_ml_job_objective: Returns the job's objective.
auto_ml_problem_type_config: Returns the configuration settings of the problem type set for the AutoML job V2.
auto_ml_problem_type_config_name: Returns the name of the problem type configuration set for the AutoML job V2.
end_time: Returns the end time of the AutoML job V2.
failure_reason: Returns the reason for the failure of the AutoML job V2, when applicable.
partial_failure_reasons: Returns a list of reasons for partial failures within an AutoML job V2.
best_candidate: Information about the candidate produced by an AutoML training job V2, including its status, steps, and other properties.
auto_ml_job_artifacts: No description is supplied for this attribute; presumably the artifacts produced by the AutoML job V2 (TODO confirm against the API reference).
resolved_attributes: Returns the resolved attributes used by the AutoML job V2.
model_deploy_config: Indicates whether the model was deployed automatically to an endpoint and the name of that endpoint if deployed automatically.
model_deploy_result: Provides information about endpoint for the model deployment.
data_split_config: Returns the configuration settings of how the data are split into train and validation datasets.
security_config: Returns the security configuration for traffic encryption or Amazon VPC settings.
auto_ml_compute_config: The compute configuration used for the AutoML job V2.
"""
# The job name is the only required field; get() and refresh() send it as "AutoMLJobName".
auto_ml_job_name: str
auto_ml_job_arn: Optional[str] = Unassigned()
auto_ml_job_input_data_config: Optional[List[AutoMLJobChannel]] = Unassigned()
output_data_config: Optional[AutoMLOutputDataConfig] = Unassigned()
role_arn: Optional[str] = Unassigned()
auto_ml_job_objective: Optional[AutoMLJobObjective] = Unassigned()
auto_ml_problem_type_config: Optional[AutoMLProblemTypeConfig] = Unassigned()
auto_ml_problem_type_config_name: Optional[str] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
end_time: Optional[datetime.datetime] = Unassigned()
last_modified_time: Optional[datetime.datetime] = Unassigned()
# Passed as the failure reason when wait() raises FailedStatusError.
failure_reason: Optional[str] = Unassigned()
partial_failure_reasons: Optional[List[AutoMLPartialFailureReason]] = Unassigned()
best_candidate: Optional[AutoMLCandidate] = Unassigned()
# Polled by wait() to detect the terminal states.
auto_ml_job_status: Optional[str] = Unassigned()
auto_ml_job_secondary_status: Optional[str] = Unassigned()
auto_ml_job_artifacts: Optional[AutoMLJobArtifacts] = Unassigned()
resolved_attributes: Optional[AutoMLResolvedAttributes] = Unassigned()
model_deploy_config: Optional[ModelDeployConfig] = Unassigned()
model_deploy_result: Optional[ModelDeployResult] = Unassigned()
data_split_config: Optional[AutoMLDataSplitConfig] = Unassigned()
security_config: Optional[AutoMLSecurityConfig] = Unassigned()
auto_ml_compute_config: Optional[AutoMLComputeConfig] = Unassigned()
def get_name(self) -> Optional[str]:
    """
    Return the name of this AutoMLJobV2 resource.

    The resource type is "auto_ml_job_v2", but the model stores its name in the
    ``auto_ml_job_name`` field (the "v2" part is not in the field name). That field
    is therefore accepted in addition to ``name`` and every suffix of
    ``auto_ml_job_v2_name``.

    Returns:
        The value of the first instance attribute (in attribute order) whose name
        is a recognised name attribute, or None, after logging an error, when the
        instance has no such attribute.
    """
    resource_name_parts = "auto_ml_job_v2_name".split("_")
    # "auto_ml_job_v2_name", "ml_job_v2_name", ..., "name"
    name_attributes = {
        "_".join(resource_name_parts[start:]) for start in range(len(resource_name_parts))
    }
    # The actual name field declared on the class; it is not a suffix of the above.
    name_attributes.add("auto_ml_job_name")
    for attribute, value in vars(self).items():
        if attribute in name_attributes:
            return value
    logger.error("Name attribute not found for object auto_ml_job_v2")
    return None
def populate_inputs_decorator(create_func):
    """
    Wrap ``create_func`` so that its keyword arguments are first passed through
    ``Base.get_updated_kwargs_with_configured_attributes`` for the "AutoMLJobV2" resource.

    Parameters:
        create_func: The function to wrap.
    Returns:
        A wrapper that forwards positional arguments unchanged and replaces the keyword
        arguments with the ones returned by the helper.
    """

    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        # The schema is rebuilt on every call, so each invocation gets fresh dict objects.
        output_data_schema = {
            "s3_output_path": {"type": "string"},
            "kms_key_id": {"type": "string"},
        }
        problem_type_schema = {
            "time_series_forecasting_job_config": {
                "feature_specification_s3_uri": {"type": "string"}
            },
            "tabular_job_config": {"feature_specification_s3_uri": {"type": "string"}},
        }
        security_schema = {
            "volume_kms_key_id": {"type": "string"},
            "vpc_config": {
                "security_group_ids": {"type": "array", "items": {"type": "string"}},
                "subnets": {"type": "array", "items": {"type": "string"}},
            },
        }
        compute_schema = {
            "emr_serverless_compute_config": {"execution_role_arn": {"type": "string"}}
        }
        schema = {
            "output_data_config": output_data_schema,
            "role_arn": {"type": "string"},
            "auto_ml_problem_type_config": problem_type_schema,
            "security_config": security_schema,
            "auto_ml_compute_config": compute_schema,
        }
        resolved_kwargs = Base.get_updated_kwargs_with_configured_attributes(
            schema, "AutoMLJobV2", **kwargs
        )
        return create_func(*args, **resolved_kwargs)

    return wrapper
[docs]
@classmethod
@populate_inputs_decorator
@Base.add_validate_call
def create(
    cls,
    auto_ml_job_name: str,
    auto_ml_job_input_data_config: List[AutoMLJobChannel],
    output_data_config: AutoMLOutputDataConfig,
    auto_ml_problem_type_config: AutoMLProblemTypeConfig,
    role_arn: str,
    tags: Optional[List[Tag]] = Unassigned(),
    security_config: Optional[AutoMLSecurityConfig] = Unassigned(),
    auto_ml_job_objective: Optional[AutoMLJobObjective] = Unassigned(),
    model_deploy_config: Optional[ModelDeployConfig] = Unassigned(),
    data_split_config: Optional[AutoMLDataSplitConfig] = Unassigned(),
    auto_ml_compute_config: Optional[AutoMLComputeConfig] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["AutoMLJobV2"]:
    """
    Create a AutoMLJobV2 resource

    Calls create_auto_ml_job_v2 with the serialized inputs and then returns the
    result of ``cls.get`` for the same job name, session and region.

    Parameters:
        auto_ml_job_name: Identifies an Autopilot job. The name must be unique to your account and is case insensitive.
        auto_ml_job_input_data_config: An array of channel objects describing the input data and their location. Each channel is a named input source. Similar to the InputDataConfig attribute in the CreateAutoMLJob input parameters. The supported formats depend on the problem type: For tabular problem types: S3Prefix, ManifestFile. For image classification: S3Prefix, ManifestFile, AugmentedManifestFile. For text classification: S3Prefix. For time-series forecasting: S3Prefix. For text generation (LLMs fine-tuning): S3Prefix.
        output_data_config: Provides information about encryption and the Amazon S3 output path needed to store artifacts from an AutoML job.
        auto_ml_problem_type_config: Defines the configuration settings of one of the supported problem types.
        role_arn: The ARN of the role that is used to access the data.
        tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, such as by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. Tag keys must be unique per resource.
        security_config: The security configuration for traffic encryption or Amazon VPC settings.
        auto_ml_job_objective: Specifies a metric to minimize or maximize as the objective of a job. If not specified, the default objective metric depends on the problem type. For the list of default values per problem type, see AutoMLJobObjective. For tabular problem types: You must either provide both the AutoMLJobObjective and indicate the type of supervised learning problem in AutoMLProblemTypeConfig (TabularJobConfig.ProblemType), or none at all. For text generation problem types (LLMs fine-tuning): Fine-tuning language models in Autopilot does not require setting the AutoMLJobObjective field. Autopilot fine-tunes LLMs without requiring multiple candidates to be trained and evaluated. Instead, using your dataset, Autopilot directly fine-tunes your target model to enhance a default objective metric, the cross-entropy loss. After fine-tuning a language model, you can evaluate the quality of its generated text using different metrics. For a list of the available metrics, see Metrics for fine-tuning LLMs in Autopilot.
        model_deploy_config: Specifies how to generate the endpoint name for an automatic one-click Autopilot model deployment.
        data_split_config: This structure specifies how to split the data into train and validation datasets. The validation and training datasets must contain the same headers. For jobs created by calling CreateAutoMLJob, the validation dataset must be less than 2 GB in size. This attribute must not be set for the time-series forecasting problem type, as Autopilot automatically splits the input dataset into training and validation sets.
        auto_ml_compute_config: Specifies the compute configuration for the AutoML job V2.
        session: Boto3 session.
        region: Region name.

    Returns:
        The AutoMLJobV2 resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceInUse: Resource being accessed is in use.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating auto_ml_job_v2 resource.")
    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Request keyed by the API's PascalCase member names.
    request = dict(
        AutoMLJobName=auto_ml_job_name,
        AutoMLJobInputDataConfig=auto_ml_job_input_data_config,
        OutputDataConfig=output_data_config,
        AutoMLProblemTypeConfig=auto_ml_problem_type_config,
        RoleArn=role_arn,
        Tags=tags,
        SecurityConfig=security_config,
        AutoMLJobObjective=auto_ml_job_objective,
        ModelDeployConfig=model_deploy_config,
        DataSplitConfig=data_split_config,
        AutoMLComputeConfig=auto_ml_compute_config,
    )
    request = Base.populate_chained_attributes(
        resource_name="AutoMLJobV2", operation_input_args=request
    )
    logger.debug(f"Input request: {request}")

    # Serialize before handing the request to the client.
    request = serialize(request)
    logger.debug(f"Serialized input request: {request}")

    response = client.create_auto_ml_job_v2(**request)
    logger.debug(f"Response: {response}")

    # Describe the job that was just created so the caller gets a populated resource.
    return cls.get(auto_ml_job_name=auto_ml_job_name, session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    auto_ml_job_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["AutoMLJobV2"]:
    """
    Get a AutoMLJobV2 resource

    Calls describe_auto_ml_job_v2 and builds the resource from the transformed
    "DescribeAutoMLJobV2Response".

    Parameters:
        auto_ml_job_name: Requests information about an AutoML job V2 using its unique name.
        session: Boto3 session.
        region: Region name.

    Returns:
        The AutoMLJobV2 resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    request = serialize({"AutoMLJobName": auto_ml_job_name})
    logger.debug(f"Serialized input request: {request}")

    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    response = client.describe_auto_ml_job_v2(**request)
    logger.debug(response)

    # Deserialize the response straight into a new instance.
    return cls(**transform(response, "DescribeAutoMLJobV2Response"))
[docs]
@Base.add_validate_call
def refresh(
    self,
) -> Optional["AutoMLJobV2"]:
    """
    Refresh a AutoMLJobV2 resource

    Describes the job named ``self.auto_ml_job_name`` using a client obtained from
    ``Base.get_sagemaker_client()`` with no explicit session or region, and updates
    this instance in place from the response.

    Returns:
        The AutoMLJobV2 resource (this same instance).

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    request = serialize({"AutoMLJobName": self.auto_ml_job_name})
    logger.debug(f"Serialized input request: {request}")

    describe_response = Base.get_sagemaker_client().describe_auto_ml_job_v2(**request)

    # Passing ``self`` makes transform() update this object.
    transform(describe_response, "DescribeAutoMLJobV2Response", self)
    return self
[docs]
@Base.add_validate_call
def wait(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a AutoMLJobV2 resource.

    Refreshes the resource, then sleeps ``poll`` seconds, repeatedly, until
    ``auto_ml_job_status`` is one of "Completed", "Failed" or "Stopped". Progress is
    shown in a transient rich panel while waiting.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out. None waits indefinitely.

    Raises:
        TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
        FailedStatusError: If the resource reaches a failed state.
    """
    terminal_states = ["Completed", "Failed", "Stopped"]
    start_time = time.time()

    progress = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    progress.add_task("Waiting for AutoMLJobV2...")
    status = Status("Current status:")

    with Live(
        Panel(
            Group(progress, status),
            title="Wait Log Panel",
            border_style=Style(color=Color.BLUE.value),
        ),
        transient=True,
    ):
        while True:
            self.refresh()
            current_status = self.auto_ml_job_status
            status.update(f"Current status: [bold]{current_status}")

            if current_status in terminal_states:
                logger.info(f"Final Resource Status: [bold]{current_status}")

                if "failed" in current_status.lower():
                    raise FailedStatusError(
                        resource_type="AutoMLJobV2",
                        status=current_status,
                        reason=self.failure_reason,
                    )

                return

            # The timeout is only evaluated after a non-terminal poll.
            if timeout is not None and time.time() - start_time >= timeout:
                raise TimeoutExceededError(resource_type="AutoMLJobV2", status=current_status)
            time.sleep(poll)
[docs]
class Cluster(Base):
"""
Class representing resource Cluster
Instances are built from the DescribeCluster response (see get() and refresh()).
Only cluster_name is required; every other attribute defaults to Unassigned().
Attributes:
cluster_arn: The Amazon Resource Name (ARN) of the SageMaker HyperPod cluster.
cluster_status: The status of the SageMaker HyperPod cluster.
instance_groups: The instance groups of the SageMaker HyperPod cluster.
cluster_name: The name of the SageMaker HyperPod cluster.
creation_time: The time when the SageMaker Cluster is created.
failure_message: The failure message of the SageMaker HyperPod cluster.
vpc_config: No description is supplied for this attribute; presumably the VPC configuration (security group IDs and subnets) of the cluster (TODO confirm against the API reference).
orchestrator: The type of orchestrator used for the SageMaker HyperPod cluster.
node_recovery: The node recovery mode configured for the SageMaker HyperPod cluster.
"""
# The cluster name is the only required field; it is sent as "ClusterName" by the instance methods.
cluster_name: str
cluster_arn: Optional[str] = Unassigned()
# Polled by wait_for_status() and wait_for_delete().
cluster_status: Optional[str] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
failure_message: Optional[str] = Unassigned()
instance_groups: Optional[List[ClusterInstanceGroupDetails]] = Unassigned()
vpc_config: Optional[VpcConfig] = Unassigned()
orchestrator: Optional[ClusterOrchestrator] = Unassigned()
node_recovery: Optional[str] = Unassigned()
def get_name(self) -> str:
    """
    Return the name of this Cluster resource.

    Scans the instance attributes in order and returns the value of the first one
    called ``name`` or matching a suffix of ``cluster_name`` (``cluster_name``,
    ``name``).

    Returns:
        The matching attribute value, or None, after logging an error, when no
        attribute matches.
    """
    name_parts = "cluster_name".split("_")
    candidates = ["_".join(name_parts[start:]) for start in range(len(name_parts))]
    for attribute, value in vars(self).items():
        if attribute == "name" or attribute in candidates:
            return value
    logger.error("Name attribute not found for object cluster")
    return None
def populate_inputs_decorator(create_func):
    """
    Wrap ``create_func`` so that its keyword arguments are first passed through
    ``Base.get_updated_kwargs_with_configured_attributes`` for the "Cluster" resource.

    Parameters:
        create_func: The function to wrap.
    Returns:
        A wrapper that forwards positional arguments unchanged and replaces the keyword
        arguments with the ones returned by the helper.
    """

    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        # The schema is rebuilt on every call, so each invocation gets fresh dict objects.
        vpc_schema = {
            "security_group_ids": {"type": "array", "items": {"type": "string"}},
            "subnets": {"type": "array", "items": {"type": "string"}},
        }
        schema = {"vpc_config": vpc_schema}
        resolved_kwargs = Base.get_updated_kwargs_with_configured_attributes(
            schema, "Cluster", **kwargs
        )
        return create_func(*args, **resolved_kwargs)

    return wrapper
[docs]
@classmethod
@populate_inputs_decorator
@Base.add_validate_call
def create(
    cls,
    cluster_name: str,
    instance_groups: List[ClusterInstanceGroupSpecification],
    vpc_config: Optional[VpcConfig] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    orchestrator: Optional[ClusterOrchestrator] = Unassigned(),
    node_recovery: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Cluster"]:
    """
    Create a Cluster resource

    Calls create_cluster with the serialized inputs and then returns the result of
    ``cls.get`` for the same cluster name, session and region.

    Parameters:
        cluster_name: The name for the new SageMaker HyperPod cluster.
        instance_groups: The instance groups to be created in the SageMaker HyperPod cluster.
        vpc_config: Sent as VpcConfig; no description is supplied for it.
        tags: Custom tags for managing the SageMaker HyperPod cluster as an Amazon Web Services resource. You can add tags to your cluster in the same way you add them in other Amazon Web Services services that support tagging. To learn more about tagging Amazon Web Services resources in general, see Tagging Amazon Web Services Resources User Guide.
        orchestrator: The type of orchestrator to use for the SageMaker HyperPod cluster. Currently, the only supported value is "eks", which is to use an Amazon Elastic Kubernetes Service (EKS) cluster as the orchestrator.
        node_recovery: The node recovery mode for the SageMaker HyperPod cluster. When set to Automatic, SageMaker HyperPod will automatically reboot or replace faulty nodes when issues are detected. When set to None, cluster administrators will need to manually manage any faulty cluster instances.
        session: Boto3 session.
        region: Region name.

    Returns:
        The Cluster resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceInUse: Resource being accessed is in use.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating cluster resource.")
    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Request keyed by the API's PascalCase member names.
    request = dict(
        ClusterName=cluster_name,
        InstanceGroups=instance_groups,
        VpcConfig=vpc_config,
        Tags=tags,
        Orchestrator=orchestrator,
        NodeRecovery=node_recovery,
    )
    request = Base.populate_chained_attributes(
        resource_name="Cluster", operation_input_args=request
    )
    logger.debug(f"Input request: {request}")

    # Serialize before handing the request to the client.
    request = serialize(request)
    logger.debug(f"Serialized input request: {request}")

    response = client.create_cluster(**request)
    logger.debug(f"Response: {response}")

    # Describe the cluster that was just created so the caller gets a populated resource.
    return cls.get(cluster_name=cluster_name, session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    cluster_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Cluster"]:
    """
    Get a Cluster resource

    Calls describe_cluster and builds the resource from the transformed
    "DescribeClusterResponse".

    Parameters:
        cluster_name: The string name or the Amazon Resource Name (ARN) of the SageMaker HyperPod cluster.
        session: Boto3 session.
        region: Region name.

    Returns:
        The Cluster resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    request = serialize({"ClusterName": cluster_name})
    logger.debug(f"Serialized input request: {request}")

    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    response = client.describe_cluster(**request)
    logger.debug(response)

    # Deserialize the response straight into a new instance.
    return cls(**transform(response, "DescribeClusterResponse"))
[docs]
@Base.add_validate_call
def refresh(
    self,
) -> Optional["Cluster"]:
    """
    Refresh a Cluster resource

    Describes the cluster named ``self.cluster_name`` using a client obtained from
    ``Base.get_sagemaker_client()`` with no explicit session or region, and updates
    this instance in place from the response.

    Returns:
        The Cluster resource (this same instance).

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    request = serialize({"ClusterName": self.cluster_name})
    logger.debug(f"Serialized input request: {request}")

    describe_response = Base.get_sagemaker_client().describe_cluster(**request)

    # Passing ``self`` makes transform() update this object.
    transform(describe_response, "DescribeClusterResponse", self)
    return self
[docs]
@populate_inputs_decorator
@Base.add_validate_call
def update(
    self,
    instance_groups: List[ClusterInstanceGroupSpecification],
    node_recovery: Optional[str] = Unassigned(),
) -> Optional["Cluster"]:
    """
    Update a Cluster resource

    Calls update_cluster for ``self.cluster_name`` and then refreshes this instance.

    Parameters:
        instance_groups: Instance group specifications, sent as InstanceGroups.
        node_recovery: Node recovery mode, sent as NodeRecovery.

    Returns:
        The Cluster resource (this same instance, refreshed).

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ResourceNotFound: Resource being access is not found.
    """
    logger.info("Updating cluster resource.")
    client = Base.get_sagemaker_client()

    request = dict(
        ClusterName=self.cluster_name,
        InstanceGroups=instance_groups,
        NodeRecovery=node_recovery,
    )
    logger.debug(f"Input request: {request}")

    # Serialize before handing the request to the client.
    request = serialize(request)
    logger.debug(f"Serialized input request: {request}")

    # update the resource
    response = client.update_cluster(**request)
    logger.debug(f"Response: {response}")

    # Pull the post-update state back into this instance.
    self.refresh()
    return self
[docs]
@Base.add_validate_call
def delete(
    self,
) -> None:
    """
    Delete a Cluster resource

    Calls delete_cluster for ``self.cluster_name`` and logs the deletion once the
    call has returned.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact.
        ResourceNotFound: Resource being access is not found.
    """
    client = Base.get_sagemaker_client()

    request = serialize({"ClusterName": self.cluster_name})
    logger.debug(f"Serialized input request: {request}")

    client.delete_cluster(**request)

    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def wait_for_status(
    self,
    target_status: Literal[
        "Creating",
        "Deleting",
        "Failed",
        "InService",
        "RollingBack",
        "SystemUpdating",
        "Updating",
    ],
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a Cluster resource to reach certain status.

    Refreshes the resource, then sleeps ``poll`` seconds, repeatedly, until
    ``cluster_status`` equals ``target_status``. Progress is shown in a transient
    rich panel while waiting.

    Parameters:
        target_status: The status to wait for.
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out. None waits indefinitely.

    Raises:
        TimeoutExceededError: If the resource does not reach the target status before the timeout.
        FailedStatusError: If the resource reaches a failed state that is not the target status.
    """
    start_time = time.time()

    progress = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    progress.add_task(f"Waiting for Cluster to reach [bold]{target_status} status...")
    status = Status("Current status:")

    with Live(
        Panel(
            Group(progress, status),
            title="Wait Log Panel",
            border_style=Style(color=Color.BLUE.value),
        ),
        transient=True,
    ):
        while True:
            self.refresh()
            current_status = self.cluster_status
            status.update(f"Current status: [bold]{current_status}")

            # Checked before the failure test so that waiting for "Failed" itself returns normally.
            if target_status == current_status:
                logger.info(f"Final Resource Status: [bold]{current_status}")
                return

            if "failed" in current_status.lower():
                raise FailedStatusError(
                    resource_type="Cluster", status=current_status, reason="(Unknown)"
                )

            if timeout is not None and time.time() - start_time >= timeout:
                raise TimeoutExceededError(resource_type="Cluster", status=current_status)
            time.sleep(poll)
[docs]
@Base.add_validate_call
def wait_for_delete(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a Cluster resource to be deleted.

    Refreshes the resource, then sleeps ``poll`` seconds, repeatedly, until the
    describe call fails with an error code containing "ResourceNotFound" or
    "ValidationException", which is treated as the resource being gone.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out. None waits indefinitely.

    Raises:
        botocore.exceptions.ClientError: Re-raised for any AWS service error whose code
            is not one of the two treated as "deleted". The message and code are
            available on the exception as ``e.response['Error']['Message']`` and
            ``e.response['Error']['Code']``.
        TimeoutExceededError: If the resource still exists once the timeout has elapsed.
    """
    start_time = time.time()

    progress = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    progress.add_task("Waiting for Cluster to be deleted...")
    status = Status("Current status:")

    with Live(
        Panel(
            Group(progress, status),
            title="Wait Log Panel",
            border_style=Style(color=Color.BLUE.value),
        )
    ):
        while True:
            try:
                self.refresh()
                current_status = self.cluster_status
                status.update(f"Current status: [bold]{current_status}")

                if timeout is not None and time.time() - start_time >= timeout:
                    raise TimeoutExceededError(resource_type="Cluster", status=current_status)
            except botocore.exceptions.ClientError as e:
                error_code = e.response["Error"]["Code"]

                # A failed describe with one of these codes means the cluster no longer exists.
                if "ResourceNotFound" in error_code or "ValidationException" in error_code:
                    logger.info("Resource was not found. It may have been deleted.")
                    return
                raise
            time.sleep(poll)
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    name_contains: Optional[str] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["Cluster"]:
    """
    Get all Cluster resources

    Returns a ResourceIterator configured for the list_clusters operation. The
    MaxResults and NextToken inputs of that operation are not exposed as
    parameters here.

    Parameters:
        creation_time_after: Set a start time for the time range during which you want to list SageMaker HyperPod clusters. Timestamps are formatted according to the ISO 8601 standard. Acceptable formats include: YYYY-MM-DDThh:mm:ss.sssTZD (UTC), for example, 2014-10-01T20:30:00.000Z; YYYY-MM-DDThh:mm:ss.sssTZD (with offset), for example, 2014-10-01T12:30:00.000-08:00; YYYY-MM-DD, for example, 2014-10-01; Unix time in seconds, for example, 1412195400. This is also referred to as Unix Epoch time and represents the number of seconds since midnight, January 1, 1970 UTC. For more information about the timestamp format, see Timestamp in the Amazon Web Services Command Line Interface User Guide.
        creation_time_before: Set an end time for the time range during which you want to list SageMaker HyperPod clusters. A filter that returns nodes in a SageMaker HyperPod cluster created before the specified time. The acceptable formats are the same as the timestamp formats for CreationTimeAfter. For more information about the timestamp format, see Timestamp in the Amazon Web Services Command Line Interface User Guide.
        name_contains: Sent as NameContains. The service documentation describes it as "Set the maximum number of instances to print in the list"; presumably it filters clusters by a name substring (TODO confirm against the API reference).
        sort_by: The field by which to sort results. The default value is CREATION_TIME.
        sort_order: The sort order for results. The default value is Ascending.
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed Cluster resources.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
    """
    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    list_filters = serialize(
        dict(
            CreationTimeAfter=creation_time_after,
            CreationTimeBefore=creation_time_before,
            NameContains=name_contains,
            SortBy=sort_by,
            SortOrder=sort_order,
        )
    )
    logger.debug(f"Serialized input request: {list_filters}")

    return ResourceIterator(
        client=client,
        list_method="list_clusters",
        summaries_key="ClusterSummaries",
        summary_name="ClusterSummary",
        resource_cls=Cluster,
        list_method_kwargs=list_filters,
    )
[docs]
@Base.add_validate_call
def get_node(
    self,
    node_id: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional[ClusterNodeDetails]:
    """
    Retrieves information of a node (also called a instance interchangeably) of a SageMaker HyperPod cluster.

    Calls describe_cluster_node for ``self.cluster_name`` and ``node_id`` and builds
    the result from the transformed "DescribeClusterNodeResponse".

    Parameters:
        node_id: The ID of the SageMaker HyperPod cluster node.
        session: Boto3 session.
        region: Region name.

    Returns:
        ClusterNodeDetails

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available on the exception as
            ``e.response['Error']['Message']`` and ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    request = serialize({"ClusterName": self.cluster_name, "NodeId": node_id})
    logger.debug(f"Serialized input request: {request}")

    client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling describe_cluster_node API")
    response = client.describe_cluster_node(**request)
    logger.debug(f"Response: {response}")

    node_fields = transform(response, "DescribeClusterNodeResponse")
    return ClusterNodeDetails(**node_fields)
[docs]
@Base.add_validate_call
def get_all_nodes(
    self,
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    instance_group_name_contains: Optional[str] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator[ClusterNodeDetails]:
    """
    List the instances (also called nodes) of this SageMaker HyperPod cluster.

    Parameters:
        creation_time_after: Only return nodes created after this time.
            Timestamps follow ISO 8601, e.g. 2014-10-01T20:30:00.000Z,
            2014-10-01T12:30:00.000-08:00, 2014-10-01, or Unix time in
            seconds such as 1412195400.
        creation_time_before: Only return nodes created before this time.
            Accepts the same formats as creation_time_after.
        instance_group_name_contains: Only return instance groups whose name
            contains this string.
        sort_by: The field by which to sort results. The default value is CREATION_TIME.
        sort_order: The sort order for results. The default value is Ascending.
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed ClusterNodeDetails.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    list_filters = serialize(
        {
            "ClusterName": self.cluster_name,
            "CreationTimeAfter": creation_time_after,
            "CreationTimeBefore": creation_time_before,
            "InstanceGroupNameContains": instance_group_name_contains,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
    )
    logger.debug(f"Serialized input request: {list_filters}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # The iterator is handed the list call and its filters; no API call is made here.
    return ResourceIterator(
        client=sagemaker_client,
        list_method="list_cluster_nodes",
        summaries_key="ClusterNodeSummaries",
        summary_name="ClusterNodeSummary",
        resource_cls=ClusterNodeDetails,
        list_method_kwargs=list_filters,
    )
[docs]
@Base.add_validate_call
def update_software(
    self,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> None:
    """
    Update the platform software of this SageMaker HyperPod cluster for security patching.

    Parameters:
        session: Boto3 session.
        region: Region name.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ConflictException: There was a conflict when you attempted to modify a
            SageMaker entity such as an Experiment or Artifact.
        ResourceNotFound: The resource being accessed is not found.
    """
    request = serialize({"ClusterName": self.cluster_name})
    logger.debug(f"Serialized input request: {request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling update_cluster_software API")
    # The response is only logged; this method returns nothing to the caller.
    api_response = sagemaker_client.update_cluster_software(**request)
    logger.debug(f"Response: {api_response}")
[docs]
@Base.add_validate_call
def batch_delete_nodes(
    self,
    node_ids: List[str],
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional[BatchDeleteClusterNodesResponse]:
    """
    Delete specific nodes within this SageMaker HyperPod cluster.

    Parameters:
        node_ids: A list of node IDs to be deleted from the cluster. For
            SageMaker HyperPod clusters using the Slurm workload manager, you
            cannot remove instances that are configured as Slurm controller nodes.
        session: Boto3 session.
        region: Region name.

    Returns:
        BatchDeleteClusterNodesResponse

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    request = serialize(
        {
            "ClusterName": self.cluster_name,
            "NodeIds": node_ids,
        }
    )
    logger.debug(f"Serialized input request: {request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling batch_delete_cluster_nodes API")
    response = sagemaker_client.batch_delete_cluster_nodes(**request)
    logger.debug(f"Response: {response}")

    # Convert the raw API payload into keyword arguments for the shape class.
    result_fields = transform(response, "BatchDeleteClusterNodesResponse")
    return BatchDeleteClusterNodesResponse(**result_fields)
[docs]
class CodeRepository(Base):
    """
    Resource class for a SageMaker CodeRepository (a Git repository).

    Attributes:
        code_repository_name: The name of the Git repository.
        code_repository_arn: The Amazon Resource Name (ARN) of the Git repository.
        creation_time: The date and time that the repository was created.
        last_modified_time: The date and time that the repository was last changed.
        git_config: Configuration details about the repository, including the URL
            where the repository is located, the default branch, and the ARN of
            the Amazon Web Services Secrets Manager secret that contains the
            credentials used to access the repository.
    """

    code_repository_name: str
    code_repository_arn: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    git_config: Optional[GitConfig] = Unassigned()

    def get_name(self) -> str:
        """
        Return the value of the first instance attribute that names this resource.

        An attribute qualifies when it is called "name" or is any underscore-joined
        suffix of "code_repository_name" ("code_repository_name",
        "repository_name", "name"). Logs an error and returns None when no
        attribute qualifies.
        """
        name_parts = "code_repository_name".split("_")
        accepted_names = ["_".join(name_parts[start:]) for start in range(len(name_parts))]
        for attr_name, attr_value in vars(self).items():
            if attr_name == "name" or attr_name in accepted_names:
                return attr_value
        logger.error("Name attribute not found for object code_repository")
        return None

    @classmethod
    @Base.add_validate_call
    def create(
        cls,
        code_repository_name: str,
        git_config: GitConfig,
        tags: Optional[List[Tag]] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["CodeRepository"]:
        """
        Create a CodeRepository resource

        Parameters:
            code_repository_name: The name of the Git repository. The name must
                have 1 to 63 characters. Valid characters are a-z, A-Z, 0-9,
                and - (hyphen).
            git_config: Specifies details about the repository, including the URL
                where the repository is located, the default branch, and
                credentials to use to access the repository.
            tags: An array of key-value pairs. You can use tags to categorize your
                Amazon Web Services resources in different ways, for example, by
                purpose, owner, or environment.
            session: Boto3 session.
            region: Region name.

        Returns:
            The CodeRepository resource.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors.
                The error code and message can be read from the exception:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """
        logger.info("Creating code_repository resource.")
        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        raw_request = Base.populate_chained_attributes(
            resource_name="CodeRepository",
            operation_input_args={
                "CodeRepositoryName": code_repository_name,
                "GitConfig": git_config,
                "Tags": tags,
            },
        )
        logger.debug(f"Input request: {raw_request}")

        request = serialize(raw_request)
        logger.debug(f"Serialized input request: {request}")

        response = sagemaker_client.create_code_repository(**request)
        logger.debug(f"Response: {response}")

        # Re-describe the new repository so the returned object carries the
        # fields from the describe call rather than only the create inputs.
        return cls.get(code_repository_name=code_repository_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        code_repository_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["CodeRepository"]:
        """
        Get a CodeRepository resource

        Parameters:
            code_repository_name: The name of the Git repository to describe.
            session: Boto3 session.
            region: Region name.

        Returns:
            The CodeRepository resource.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors.
                The error code and message can be read from the exception:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        request = serialize({"CodeRepositoryName": code_repository_name})
        logger.debug(f"Serialized input request: {request}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = sagemaker_client.describe_code_repository(**request)
        logger.debug(response)

        # Deserialize the describe response into constructor keyword arguments.
        resource_fields = transform(response, "DescribeCodeRepositoryOutput")
        return cls(**resource_fields)

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["CodeRepository"]:
        """
        Refresh a CodeRepository resource

        Re-describes the repository with the default client and updates this
        object in place.

        Returns:
            The CodeRepository resource.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors.
                The error code and message can be read from the exception:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        request = serialize({"CodeRepositoryName": self.code_repository_name})
        logger.debug(f"Serialized input request: {request}")

        sagemaker_client = Base.get_sagemaker_client()
        response = sagemaker_client.describe_code_repository(**request)

        # Passing self as the third argument makes transform update this object.
        transform(response, "DescribeCodeRepositoryOutput", self)
        return self

    @Base.add_validate_call
    def update(
        self,
        git_config: Optional[GitConfigForUpdate] = Unassigned(),
    ) -> Optional["CodeRepository"]:
        """
        Update a CodeRepository resource

        Parameters:
            git_config: The Git configuration to send in the update request.

        Returns:
            The CodeRepository resource.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors.
                The error code and message can be read from the exception:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ConflictException: There was a conflict when you attempted to modify a
                SageMaker entity such as an Experiment or Artifact.
        """
        logger.info("Updating code_repository resource.")
        sagemaker_client = Base.get_sagemaker_client()

        raw_request = {
            "CodeRepositoryName": self.code_repository_name,
            "GitConfig": git_config,
        }
        logger.debug(f"Input request: {raw_request}")

        request = serialize(raw_request)
        logger.debug(f"Serialized input request: {request}")

        response = sagemaker_client.update_code_repository(**request)
        logger.debug(f"Response: {response}")

        # Pull the post-update state back into this object before returning it.
        self.refresh()
        return self

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete a CodeRepository resource

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors.
                The error code and message can be read from the exception:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        sagemaker_client = Base.get_sagemaker_client()

        request = serialize({"CodeRepositoryName": self.code_repository_name})
        logger.debug(f"Serialized input request: {request}")

        sagemaker_client.delete_code_repository(**request)
        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        creation_time_after: Optional[datetime.datetime] = Unassigned(),
        creation_time_before: Optional[datetime.datetime] = Unassigned(),
        last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
        last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
        name_contains: Optional[str] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["CodeRepository"]:
        """
        Gets a list of the Git repositories in your account.

        Parameters:
            creation_time_after: A filter that returns only Git repositories that
                were created after the specified time.
            creation_time_before: A filter that returns only Git repositories that
                were created before the specified time.
            last_modified_time_after: A filter that returns only Git repositories
                that were last modified after the specified time.
            last_modified_time_before: A filter that returns only Git repositories
                that were last modified before the specified time.
            name_contains: A string in the Git repositories name. This filter
                returns only repositories whose name contains the specified string.
            sort_by: The field to sort results by. The default is Name.
            sort_order: The sort order for results. The default is Ascending.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed CodeRepository.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors.
                The error code and message can be read from the exception:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        list_filters = serialize(
            {
                "CreationTimeAfter": creation_time_after,
                "CreationTimeBefore": creation_time_before,
                "LastModifiedTimeAfter": last_modified_time_after,
                "LastModifiedTimeBefore": last_modified_time_before,
                "NameContains": name_contains,
                "SortBy": sort_by,
                "SortOrder": sort_order,
            }
        )
        logger.debug(f"Serialized input request: {list_filters}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        # The iterator is handed the list call and its filters; no API call is made here.
        return ResourceIterator(
            client=sagemaker_client,
            list_method="list_code_repositories",
            summaries_key="CodeRepositorySummaryList",
            summary_name="CodeRepositorySummary",
            resource_cls=CodeRepository,
            list_method_kwargs=list_filters,
        )
[docs]
class CompilationJob(Base):
"""
Class representing resource CompilationJob
Attributes:
compilation_job_name: The name of the model compilation job.
compilation_job_arn: The Amazon Resource Name (ARN) of the model compilation job.
compilation_job_status: The status of the model compilation job.
stopping_condition: Specifies a limit to how long a model compilation job can run. When the job reaches the time limit, Amazon SageMaker ends the compilation job. Use this API to cap model training costs.
creation_time: The time that the model compilation job was created.
last_modified_time: The time that the status of the model compilation job was last modified.
failure_reason: If a model compilation job failed, the reason it failed.
model_artifacts: Information about the location in Amazon S3 that has been configured for storing the model artifacts used in the compilation job.
role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker assumes to perform the model compilation job.
input_config: Information about the location in Amazon S3 of the input model artifacts, the name and shape of the expected data inputs, and the framework in which the model was trained.
output_config: Information about the output location for the compiled model and the target device that the model runs on.
compilation_start_time: The time when the model compilation job started the CompilationJob instances. You are billed for the time between this timestamp and the timestamp in the CompilationEndTime field. In Amazon CloudWatch Logs, the start time might be later than this time. That's because it takes time to download the compilation job, which depends on the size of the compilation job container.
compilation_end_time: The time when the model compilation job on a compilation job instance ended. For a successful or stopped job, this is when the job's model artifacts have finished uploading. For a failed job, this is when Amazon SageMaker detected that the job failed.
inference_image: The inference image to use when compiling a model. Specify an image only if the target device is a cloud instance.
model_package_version_arn: The Amazon Resource Name (ARN) of the versioned model package that was provided to SageMaker Neo when you initiated a compilation job.
model_digests: Provides a BLAKE2 hash value that identifies the compiled model artifacts in Amazon S3.
vpc_config: A VpcConfig object that specifies the VPC that you want your compilation job to connect to. Control access to your models by configuring the VPC. For more information, see Protect Compilation Jobs by Using an Amazon Virtual Private Cloud.
derived_information: Information that SageMaker Neo automatically derived about the model.
"""
compilation_job_name: str
compilation_job_arn: Optional[str] = Unassigned()
compilation_job_status: Optional[str] = Unassigned()
compilation_start_time: Optional[datetime.datetime] = Unassigned()
compilation_end_time: Optional[datetime.datetime] = Unassigned()
stopping_condition: Optional[StoppingCondition] = Unassigned()
inference_image: Optional[str] = Unassigned()
model_package_version_arn: Optional[str] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
last_modified_time: Optional[datetime.datetime] = Unassigned()
failure_reason: Optional[str] = Unassigned()
model_artifacts: Optional[ModelArtifacts] = Unassigned()
model_digests: Optional[ModelDigests] = Unassigned()
role_arn: Optional[str] = Unassigned()
input_config: Optional[InputConfig] = Unassigned()
output_config: Optional[OutputConfig] = Unassigned()
vpc_config: Optional[NeoVpcConfig] = Unassigned()
derived_information: Optional[DerivedInformation] = Unassigned()
def get_name(self) -> str:
    """
    Return the value of the first instance attribute that names this resource.

    An attribute qualifies when it is called "name" or is any underscore-joined
    suffix of "compilation_job_name" ("compilation_job_name", "job_name",
    "name"). Logs an error and returns None when no attribute qualifies.
    """
    name_parts = "compilation_job_name".split("_")
    accepted_names = ["_".join(name_parts[start:]) for start in range(len(name_parts))]
    for attr_name, attr_value in vars(self).items():
        if attr_name == "name" or attr_name in accepted_names:
            return attr_value
    logger.error("Name attribute not found for object compilation_job")
    return None
def populate_inputs_decorator(create_func):
    """
    Wrap ``create_func`` so its keyword arguments are first passed through
    ``Base.get_updated_kwargs_with_configured_attributes`` together with the
    CompilationJob config schema. Positional arguments are forwarded unchanged.
    """

    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        # Attribute paths handed to the Base helper along with the caller's kwargs.
        vpc_schema = {
            "security_group_ids": {"type": "array", "items": {"type": "string"}},
            "subnets": {"type": "array", "items": {"type": "string"}},
        }
        output_schema = {
            "s3_output_location": {"type": "string"},
            "kms_key_id": {"type": "string"},
        }
        config_schema_for_resource = {
            "model_artifacts": {"s3_model_artifacts": {"type": "string"}},
            "role_arn": {"type": "string"},
            "input_config": {"s3_uri": {"type": "string"}},
            "output_config": output_schema,
            "vpc_config": vpc_schema,
        }
        updated_kwargs = Base.get_updated_kwargs_with_configured_attributes(
            config_schema_for_resource, "CompilationJob", **kwargs
        )
        return create_func(*args, **updated_kwargs)

    return wrapper
[docs]
@classmethod
@populate_inputs_decorator
@Base.add_validate_call
def create(
    cls,
    compilation_job_name: str,
    role_arn: str,
    output_config: OutputConfig,
    stopping_condition: StoppingCondition,
    model_package_version_arn: Optional[str] = Unassigned(),
    input_config: Optional[InputConfig] = Unassigned(),
    vpc_config: Optional[NeoVpcConfig] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["CompilationJob"]:
    """
    Create a CompilationJob resource

    Parameters:
        compilation_job_name: A name for the model compilation job. The name must
            be unique within the Amazon Web Services Region and within your
            Amazon Web Services account.
        role_arn: The Amazon Resource Name (ARN) of an IAM role that enables
            Amazon SageMaker to perform tasks on your behalf. During model
            compilation, Amazon SageMaker needs your permission to read input
            data from an S3 bucket, write model artifacts to an S3 bucket, write
            logs to Amazon CloudWatch Logs, and publish metrics to Amazon
            CloudWatch. The caller of this API must have the iam:PassRole
            permission to pass this role to Amazon SageMaker.
        output_config: Provides information about the output location for the
            compiled model and the target device the model runs on.
        stopping_condition: Specifies a limit to how long a model compilation job
            can run. When the job reaches the time limit, Amazon SageMaker ends
            the compilation job. Use this API to cap model training costs.
        model_package_version_arn: The ARN of a versioned model package. Provide
            either a ModelPackageVersionArn or an InputConfig object in the
            request; the presence of both in the CreateCompilationJob request
            will return an exception.
        input_config: Provides information about the location of input model
            artifacts, the name and shape of the expected data inputs, and the
            framework in which the model was trained.
        vpc_config: A VpcConfig object that specifies the VPC that you want your
            compilation job to connect to. Control access to your models by
            configuring the VPC.
        tags: An array of key-value pairs. You can use tags to categorize your
            Amazon Web Services resources in different ways, for example, by
            purpose, owner, or environment.
        session: Boto3 session.
        region: Region name.

    Returns:
        The CompilationJob resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceInUse: Resource being accessed is in use.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating compilation_job resource.")
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    raw_request = Base.populate_chained_attributes(
        resource_name="CompilationJob",
        operation_input_args={
            "CompilationJobName": compilation_job_name,
            "RoleArn": role_arn,
            "ModelPackageVersionArn": model_package_version_arn,
            "InputConfig": input_config,
            "OutputConfig": output_config,
            "VpcConfig": vpc_config,
            "StoppingCondition": stopping_condition,
            "Tags": tags,
        },
    )
    logger.debug(f"Input request: {raw_request}")

    request = serialize(raw_request)
    logger.debug(f"Serialized input request: {request}")

    response = sagemaker_client.create_compilation_job(**request)
    logger.debug(f"Response: {response}")

    # Re-describe the new job so the returned object carries the fields from
    # the describe call rather than only the create inputs.
    return cls.get(compilation_job_name=compilation_job_name, session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    compilation_job_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["CompilationJob"]:
    """
    Get a CompilationJob resource

    Parameters:
        compilation_job_name: The name of the model compilation job that you
            want information about.
        session: Boto3 session.
        region: Region name.

    Returns:
        The CompilationJob resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    request = serialize({"CompilationJobName": compilation_job_name})
    logger.debug(f"Serialized input request: {request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    response = sagemaker_client.describe_compilation_job(**request)
    logger.debug(response)

    # Deserialize the describe response into constructor keyword arguments.
    resource_fields = transform(response, "DescribeCompilationJobResponse")
    return cls(**resource_fields)
[docs]
@Base.add_validate_call
def refresh(
    self,
) -> Optional["CompilationJob"]:
    """
    Refresh a CompilationJob resource

    Re-describes the job with the default client and updates this object in place.

    Returns:
        The CompilationJob resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    request = serialize({"CompilationJobName": self.compilation_job_name})
    logger.debug(f"Serialized input request: {request}")

    sagemaker_client = Base.get_sagemaker_client()
    response = sagemaker_client.describe_compilation_job(**request)

    # Passing self as the third argument makes transform update this object.
    transform(response, "DescribeCompilationJobResponse", self)
    return self
[docs]
@Base.add_validate_call
def delete(
    self,
) -> None:
    """
    Delete a CompilationJob resource

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    sagemaker_client = Base.get_sagemaker_client()

    request = serialize({"CompilationJobName": self.compilation_job_name})
    logger.debug(f"Serialized input request: {request}")

    sagemaker_client.delete_compilation_job(**request)
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def stop(self) -> None:
    """
    Stop a CompilationJob resource

    Sends a StopCompilationJob request for this job's name using the default
    SageMaker client.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error code and message can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    # Obtain the client through the shared Base helper, exactly like refresh()
    # and delete(), instead of reaching into SageMakerClient's attributes.
    client = Base.get_sagemaker_client()

    operation_input_args = serialize({"CompilationJobName": self.compilation_job_name})
    logger.debug(f"Serialized input request: {operation_input_args}")

    client.stop_compilation_job(**operation_input_args)
    logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def wait(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a CompilationJob resource.

    Repeatedly refreshes this object and returns once ``compilation_job_status``
    is one of COMPLETED, FAILED or STOPPED, showing a transient progress panel
    while polling.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
            ``None`` disables the timeout check.

    Raises:
        TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
        FailedStatusError: If the resource reaches a failed state.
        WaiterError: Raised when an error occurs while waiting.
    """
    terminal_states = ["COMPLETED", "FAILED", "STOPPED"]
    start_time = time.time()

    progress = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    progress.add_task("Waiting for CompilationJob...")
    status = Status("Current status:")

    with Live(
        Panel(
            Group(progress, status),
            title="Wait Log Panel",
            border_style=Style(color=Color.BLUE.value),
        ),
        transient=True,
    ):
        while True:
            self.refresh()
            current_status = self.compilation_job_status
            status.update(f"Current status: [bold]{current_status}")

            if current_status in terminal_states:
                logger.info(f"Final Resource Status: [bold]{current_status}")

                if "failed" in current_status.lower():
                    raise FailedStatusError(
                        resource_type="CompilationJob",
                        status=current_status,
                        reason=self.failure_reason,
                    )

                return

            # The timeout is checked only after a non-terminal poll, so a job
            # that finishes on the last poll still returns normally.
            if timeout is not None and time.time() - start_time >= timeout:
                # Keyword corrected from the misspelled `resouce_type` so it
                # matches the `resource_type` keyword used for FailedStatusError.
                raise TimeoutExceededError(
                    resource_type="CompilationJob", status=current_status
                )
            time.sleep(poll)
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
    name_contains: Optional[str] = Unassigned(),
    status_equals: Optional[str] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["CompilationJob"]:
    """
    List CompilationJob resources.

    Parameters:
        creation_time_after: A filter that returns the model compilation jobs that were created after a specified time.
        creation_time_before: A filter that returns the model compilation jobs that were created before a specified time.
        last_modified_time_after: A filter that returns the model compilation jobs that were modified after a specified time.
        last_modified_time_before: A filter that returns the model compilation jobs that were modified before a specified time.
        name_contains: A filter that returns the model compilation jobs whose name contains a specified string.
        status_equals: A filter that retrieves model compilation jobs with a specific CompilationJobStatus status.
        sort_by: The field by which to sort results. The default is CreationTime.
        sort_order: The sort order for results. The default is Ascending.
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed CompilationJob resources.

    Raises:
        botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
            The error message and error code can be parsed from the exception as follows:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Gather the filters under their API key names and serialize them in one step.
    list_filters = serialize(
        {
            "CreationTimeAfter": creation_time_after,
            "CreationTimeBefore": creation_time_before,
            "LastModifiedTimeAfter": last_modified_time_after,
            "LastModifiedTimeBefore": last_modified_time_before,
            "NameContains": name_contains,
            "StatusEquals": status_equals,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
    )
    logger.debug(f"Serialized input request: {list_filters}")

    return ResourceIterator(
        client=sagemaker_client,
        list_method="list_compilation_jobs",
        summaries_key="CompilationJobSummaries",
        summary_name="CompilationJobSummary",
        resource_cls=CompilationJob,
        list_method_kwargs=list_filters,
    )
[docs]
class Context(Base):
    """
    Class representing resource Context

    Attributes:
        context_name: The name of the context.
        context_arn: The Amazon Resource Name (ARN) of the context.
        source: The source of the context.
        context_type: The type of the context.
        description: The description of the context.
        properties: A list of the context's properties.
        creation_time: When the context was created.
        created_by:
        last_modified_time: When the context was last modified.
        last_modified_by:
        lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group.
    """

    context_name: str
    context_arn: Optional[str] = Unassigned()
    source: Optional[ContextSource] = Unassigned()
    context_type: Optional[str] = Unassigned()
    description: Optional[str] = Unassigned()
    properties: Optional[Dict[str, str]] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    created_by: Optional[UserContext] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    last_modified_by: Optional[UserContext] = Unassigned()
    lineage_group_arn: Optional[str] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of the attribute that holds this resource's name.

        The first instance attribute whose name is a "_"-joined suffix of
        "context_name" (i.e. "context_name" or "name") is used.

        Returns:
            The attribute value, or None (after logging an error) when no
            such attribute exists.
        """
        name_parts = "context_name".split("_")
        # Every suffix of the resource name; the last one is always "name".
        candidates = {"_".join(name_parts[start:]) for start in range(len(name_parts))}
        for attribute, value in vars(self).items():
            if attribute in candidates:
                return value
        logger.error("Name attribute not found for object context")
        return None

    @classmethod
    @Base.add_validate_call
    def create(
        cls,
        context_name: str,
        source: ContextSource,
        context_type: str,
        description: Optional[str] = Unassigned(),
        properties: Optional[Dict[str, str]] = Unassigned(),
        tags: Optional[List[Tag]] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Context"]:
        """
        Create a Context resource

        Parameters:
            context_name: The name of the context. Must be unique to your account in an Amazon Web Services Region.
            source: The source type, ID, and URI.
            context_type: The context type.
            description: The description of the context.
            properties: A list of properties to add to the context.
            tags: A list of tags to apply to the context.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Context resource, fetched with ``get`` after the create call.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """
        logger.info("Creating context resource.")
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "ContextName": context_name,
            "Source": source,
            "ContextType": context_type,
            "Description": description,
            "Properties": properties,
            "Tags": tags,
        }
        operation_input_args = Base.populate_chained_attributes(
            resource_name="Context", operation_input_args=operation_input_args
        )
        logger.debug(f"Input request: {operation_input_args}")

        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # create the resource
        response = client.create_context(**operation_input_args)
        logger.debug(f"Response: {response}")

        return cls.get(context_name=context_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        context_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Context"]:
        """
        Get a Context resource

        Parameters:
            context_name: The name of the context to describe.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Context resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        operation_input_args = {
            "ContextName": context_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = client.describe_context(**operation_input_args)
        logger.debug(response)

        # deserialize the response
        transformed_response = transform(response, "DescribeContextResponse")
        context = cls(**transformed_response)
        return context

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["Context"]:
        """
        Refresh a Context resource

        Describes the context again using a default client and updates this
        instance in place.

        Returns:
            The Context resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        operation_input_args = {
            "ContextName": self.context_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client()
        response = client.describe_context(**operation_input_args)

        # deserialize response and update self
        transform(response, "DescribeContextResponse", self)
        return self

    @Base.add_validate_call
    def update(
        self,
        description: Optional[str] = Unassigned(),
        properties: Optional[Dict[str, str]] = Unassigned(),
        properties_to_remove: Optional[List[str]] = Unassigned(),
    ) -> Optional["Context"]:
        """
        Update a Context resource

        Parameters:
            description: The new description for the context.
            properties: The new list of properties. Overwrites the current property list.
            properties_to_remove: A list of properties to remove.

        Returns:
            The Context resource, refreshed after the update call.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact.
            ResourceNotFound: Resource being access is not found.
        """
        logger.info("Updating context resource.")
        client = Base.get_sagemaker_client()

        operation_input_args = {
            "ContextName": self.context_name,
            "Description": description,
            "Properties": properties,
            "PropertiesToRemove": properties_to_remove,
        }
        logger.debug(f"Input request: {operation_input_args}")

        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # update the resource
        response = client.update_context(**operation_input_args)
        logger.debug(f"Response: {response}")

        # pull the updated state back into this instance
        self.refresh()
        return self

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete a Context resource

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        client = Base.get_sagemaker_client()

        operation_input_args = {
            "ContextName": self.context_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client.delete_context(**operation_input_args)
        # Logged only once the delete call has returned without raising.
        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        source_uri: Optional[str] = Unassigned(),
        context_type: Optional[str] = Unassigned(),
        created_after: Optional[datetime.datetime] = Unassigned(),
        created_before: Optional[datetime.datetime] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["Context"]:
        """
        Get all Context resources

        Parameters:
            source_uri: A filter that returns only contexts with the specified source URI.
            context_type: A filter that returns only contexts of the specified type.
            created_after: A filter that returns only contexts created on or after the specified time.
            created_before: A filter that returns only contexts created on or before the specified time.
            sort_by: The property used to sort results. The default value is CreationTime.
            sort_order: The sort order. The default value is Descending.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed Context resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "SourceUri": source_uri,
            "ContextType": context_type,
            "CreatedAfter": created_after,
            "CreatedBefore": created_before,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_contexts",
            summaries_key="ContextSummaries",
            summary_name="ContextSummary",
            resource_cls=Context,
            list_method_kwargs=operation_input_args,
        )
[docs]
class DataQualityJobDefinition(Base):
    """
    Class representing resource DataQualityJobDefinition

    Attributes:
        job_definition_arn: The Amazon Resource Name (ARN) of the data quality monitoring job definition.
        job_definition_name: The name of the data quality monitoring job definition.
        creation_time: The time that the data quality monitoring job definition was created.
        data_quality_app_specification: Information about the container that runs the data quality monitoring job.
        data_quality_job_input: The list of inputs for the data quality monitoring job. Currently endpoints are supported.
        data_quality_job_output_config:
        job_resources:
        role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf.
        data_quality_baseline_config: The constraints and baselines for the data quality monitoring job definition.
        network_config: The networking configuration for the data quality monitoring job.
        stopping_condition:
    """

    job_definition_name: str
    job_definition_arn: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    data_quality_baseline_config: Optional[DataQualityBaselineConfig] = Unassigned()
    data_quality_app_specification: Optional[DataQualityAppSpecification] = Unassigned()
    data_quality_job_input: Optional[DataQualityJobInput] = Unassigned()
    data_quality_job_output_config: Optional[MonitoringOutputConfig] = Unassigned()
    job_resources: Optional[MonitoringResources] = Unassigned()
    network_config: Optional[MonitoringNetworkConfig] = Unassigned()
    role_arn: Optional[str] = Unassigned()
    stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of the attribute that holds this resource's name.

        The first instance attribute whose name is a "_"-joined suffix of
        "data_quality_job_definition_name" (for example "job_definition_name"
        or "name") is used.

        Returns:
            The attribute value, or None (after logging an error) when no
            such attribute exists.
        """
        name_parts = "data_quality_job_definition_name".split("_")
        # Every suffix of the resource name; the last one is always "name".
        candidates = {"_".join(name_parts[start:]) for start in range(len(name_parts))}
        for attribute, value in vars(self).items():
            if attribute in candidates:
                return value
        logger.error("Name attribute not found for object data_quality_job_definition")
        return None

    def populate_inputs_decorator(create_func):
        """
        Wrap ``create_func`` so its keyword arguments are passed through
        ``Base.get_updated_kwargs_with_configured_attributes`` first.

        Used as a plain decorator inside this class body (it takes the
        function being decorated, not ``self``). The schema below lists the
        attributes handed to that helper for this resource.
        """

        @functools.wraps(create_func)
        def wrapper(*args, **kwargs):
            config_schema_for_resource = {
                "data_quality_job_input": {
                    "endpoint_input": {
                        "s3_input_mode": {"type": "string"},
                        "s3_data_distribution_type": {"type": "string"},
                    },
                    "batch_transform_input": {
                        "data_captured_destination_s3_uri": {"type": "string"},
                        "s3_input_mode": {"type": "string"},
                        "s3_data_distribution_type": {"type": "string"},
                    },
                },
                "data_quality_job_output_config": {"kms_key_id": {"type": "string"}},
                "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}},
                "role_arn": {"type": "string"},
                "data_quality_baseline_config": {
                    "constraints_resource": {"s3_uri": {"type": "string"}},
                    "statistics_resource": {"s3_uri": {"type": "string"}},
                },
                "network_config": {
                    "vpc_config": {
                        "security_group_ids": {"type": "array", "items": {"type": "string"}},
                        "subnets": {"type": "array", "items": {"type": "string"}},
                    }
                },
            }
            return create_func(
                *args,
                **Base.get_updated_kwargs_with_configured_attributes(
                    config_schema_for_resource, "DataQualityJobDefinition", **kwargs
                ),
            )

        return wrapper

    @classmethod
    @populate_inputs_decorator
    @Base.add_validate_call
    def create(
        cls,
        job_definition_name: str,
        data_quality_app_specification: DataQualityAppSpecification,
        data_quality_job_input: DataQualityJobInput,
        data_quality_job_output_config: MonitoringOutputConfig,
        job_resources: MonitoringResources,
        role_arn: str,
        data_quality_baseline_config: Optional[DataQualityBaselineConfig] = Unassigned(),
        network_config: Optional[MonitoringNetworkConfig] = Unassigned(),
        stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(),
        tags: Optional[List[Tag]] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["DataQualityJobDefinition"]:
        """
        Create a DataQualityJobDefinition resource

        Parameters:
            job_definition_name: The name for the monitoring job definition.
            data_quality_app_specification: Specifies the container that runs the monitoring job.
            data_quality_job_input: A list of inputs for the monitoring job. Currently endpoints are supported as monitoring inputs.
            data_quality_job_output_config:
            job_resources:
            role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf.
            data_quality_baseline_config: Configures the constraints and baselines for the monitoring job.
            network_config: Specifies networking configuration for the monitoring job.
            stopping_condition:
            tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide.
            session: Boto3 session.
            region: Region name.

        Returns:
            The DataQualityJobDefinition resource, fetched with ``get`` after the create call.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceInUse: Resource being accessed is in use.
            ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """
        logger.info("Creating data_quality_job_definition resource.")
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "JobDefinitionName": job_definition_name,
            "DataQualityBaselineConfig": data_quality_baseline_config,
            "DataQualityAppSpecification": data_quality_app_specification,
            "DataQualityJobInput": data_quality_job_input,
            "DataQualityJobOutputConfig": data_quality_job_output_config,
            "JobResources": job_resources,
            "NetworkConfig": network_config,
            "RoleArn": role_arn,
            "StoppingCondition": stopping_condition,
            "Tags": tags,
        }
        operation_input_args = Base.populate_chained_attributes(
            resource_name="DataQualityJobDefinition", operation_input_args=operation_input_args
        )
        logger.debug(f"Input request: {operation_input_args}")

        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        # create the resource
        response = client.create_data_quality_job_definition(**operation_input_args)
        logger.debug(f"Response: {response}")

        return cls.get(job_definition_name=job_definition_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        job_definition_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["DataQualityJobDefinition"]:
        """
        Get a DataQualityJobDefinition resource

        Parameters:
            job_definition_name: The name of the data quality monitoring job definition to describe.
            session: Boto3 session.
            region: Region name.

        Returns:
            The DataQualityJobDefinition resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        operation_input_args = {
            "JobDefinitionName": job_definition_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = client.describe_data_quality_job_definition(**operation_input_args)
        logger.debug(response)

        # deserialize the response
        transformed_response = transform(response, "DescribeDataQualityJobDefinitionResponse")
        data_quality_job_definition = cls(**transformed_response)
        return data_quality_job_definition

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["DataQualityJobDefinition"]:
        """
        Refresh a DataQualityJobDefinition resource

        Describes the job definition again using a default client and updates
        this instance in place.

        Returns:
            The DataQualityJobDefinition resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        operation_input_args = {
            "JobDefinitionName": self.job_definition_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client()
        response = client.describe_data_quality_job_definition(**operation_input_args)

        # deserialize response and update self
        transform(response, "DescribeDataQualityJobDefinitionResponse", self)
        return self

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete a DataQualityJobDefinition resource

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        client = Base.get_sagemaker_client()

        operation_input_args = {
            "JobDefinitionName": self.job_definition_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client.delete_data_quality_job_definition(**operation_input_args)
        # Logged only once the delete call has returned without raising.
        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        endpoint_name: Optional[str] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        name_contains: Optional[str] = Unassigned(),
        creation_time_before: Optional[datetime.datetime] = Unassigned(),
        creation_time_after: Optional[datetime.datetime] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["DataQualityJobDefinition"]:
        """
        Get all DataQualityJobDefinition resources

        Parameters:
            endpoint_name: A filter that lists the data quality job definitions associated with the specified endpoint.
            sort_by: The field to sort results by. The default is CreationTime.
            sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending.
            name_contains: A string in the data quality monitoring job definition name. This filter returns only data quality monitoring job definitions whose name contains the specified string.
            creation_time_before: A filter that returns only data quality monitoring job definitions created before the specified time.
            creation_time_after: A filter that returns only data quality monitoring job definitions created after the specified time.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed DataQualityJobDefinition resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "EndpointName": endpoint_name,
            "SortBy": sort_by,
            "SortOrder": sort_order,
            "NameContains": name_contains,
            "CreationTimeBefore": creation_time_before,
            "CreationTimeAfter": creation_time_after,
        }
        # Handed to ResourceIterator; presumably renames the summary's
        # "monitoring_*" keys to this class's attribute names - confirm
        # against ResourceIterator.
        custom_key_mapping = {
            "monitoring_job_definition_name": "job_definition_name",
            "monitoring_job_definition_arn": "job_definition_arn",
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_data_quality_job_definitions",
            summaries_key="JobDefinitionSummaries",
            summary_name="MonitoringJobDefinitionSummary",
            resource_cls=DataQualityJobDefinition,
            custom_key_mapping=custom_key_mapping,
            list_method_kwargs=operation_input_args,
        )
[docs]
class Device(Base):
    """
    Class representing resource Device

    Attributes:
        device_name: The unique identifier of the device.
        device_fleet_name: The name of the fleet the device belongs to.
        registration_time: The timestamp of the last registration or de-registration.
        device_arn: The Amazon Resource Name (ARN) of the device.
        description: A description of the device.
        iot_thing_name: The Amazon Web Services Internet of Things (IoT) object thing name associated with the device.
        latest_heartbeat: The last heartbeat received from the device.
        models: Models on the device.
        max_models: The maximum number of models.
        next_token: The response from the last list when returning a list large enough to need tokening.
        agent_version: Edge Manager agent version.
    """

    device_name: str
    device_fleet_name: str
    device_arn: Optional[str] = Unassigned()
    description: Optional[str] = Unassigned()
    iot_thing_name: Optional[str] = Unassigned()
    registration_time: Optional[datetime.datetime] = Unassigned()
    latest_heartbeat: Optional[datetime.datetime] = Unassigned()
    models: Optional[List[EdgeModel]] = Unassigned()
    max_models: Optional[int] = Unassigned()
    next_token: Optional[str] = Unassigned()
    agent_version: Optional[str] = Unassigned()

    def get_name(self) -> Optional[str]:
        """
        Return the value of the attribute that holds this resource's name.

        The first instance attribute whose name is a "_"-joined suffix of
        "device_name" (i.e. "device_name" or "name") is used.

        Returns:
            The attribute value, or None (after logging an error) when no
            such attribute exists.
        """
        name_parts = "device_name".split("_")
        # Every suffix of the resource name; the last one is always "name".
        candidates = {"_".join(name_parts[start:]) for start in range(len(name_parts))}
        for attribute, value in vars(self).items():
            if attribute in candidates:
                return value
        logger.error("Name attribute not found for object device")
        return None

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        device_name: str,
        device_fleet_name: str,
        next_token: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["Device"]:
        """
        Get a Device resource

        Parameters:
            device_name: The unique ID of the device.
            device_fleet_name: The name of the fleet the devices belong to.
            next_token: Next token of device description.
            session: Boto3 session.
            region: Region name.

        Returns:
            The Device resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        operation_input_args = {
            "NextToken": next_token,
            "DeviceName": device_name,
            "DeviceFleetName": device_fleet_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        response = client.describe_device(**operation_input_args)
        logger.debug(response)

        # deserialize the response
        transformed_response = transform(response, "DescribeDeviceResponse")
        device = cls(**transformed_response)
        return device

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["Device"]:
        """
        Refresh a Device resource

        Describes the device again using a default client, sending this
        instance's current ``next_token``, ``device_name`` and
        ``device_fleet_name``, and updates this instance in place.

        Returns:
            The Device resource.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
            ResourceNotFound: Resource being access is not found.
        """
        operation_input_args = {
            "NextToken": self.next_token,
            "DeviceName": self.device_name,
            "DeviceFleetName": self.device_fleet_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        client = Base.get_sagemaker_client()
        response = client.describe_device(**operation_input_args)

        # deserialize response and update self
        transform(response, "DescribeDeviceResponse", self)
        return self

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        latest_heartbeat_after: Optional[datetime.datetime] = Unassigned(),
        model_name: Optional[str] = Unassigned(),
        device_fleet_name: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["Device"]:
        """
        Get all Device resources

        Parameters:
            latest_heartbeat_after: Select fleets where the job was updated after X
            model_name: A filter that searches devices that contains this name in any of their models.
            device_fleet_name: Filter for fleets containing this name in their device fleet name.
            session: Boto3 session.
            region: Region name.

        Returns:
            Iterator for listed Device resources.

        Raises:
            botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
                The error message and error code can be parsed from the exception as follows:
                ```
                try:
                    # AWS service call here
                except botocore.exceptions.ClientError as e:
                    error_message = e.response['Error']['Message']
                    error_code = e.response['Error']['Code']
                ```
        """
        client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        operation_input_args = {
            "LatestHeartbeatAfter": latest_heartbeat_after,
            "ModelName": model_name,
            "DeviceFleetName": device_fleet_name,
        }
        # serialize the input request
        operation_input_args = serialize(operation_input_args)
        logger.debug(f"Serialized input request: {operation_input_args}")

        return ResourceIterator(
            client=client,
            list_method="list_devices",
            summaries_key="DeviceSummaries",
            summary_name="DeviceSummary",
            resource_cls=Device,
            list_method_kwargs=operation_input_args,
        )
[docs]
class DeviceFleet(Base):
    """
    Class representing resource DeviceFleet

    Only ``device_fleet_name`` is required; every other attribute defaults to
    ``Unassigned()``.

    Attributes:
        device_fleet_name: The name of the fleet.
        device_fleet_arn: The Amazon Resource Name (ARN) of the fleet.
        output_config: The output configuration for storing sampled data.
        description: A description of the fleet.
        creation_time: Timestamp of when the device fleet was created.
        last_modified_time: Timestamp of when the device fleet was last updated.
        role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT).
        iot_role_alias: The Amazon Resource Name (ARN) alias created in Amazon Web Services Internet of Things (IoT).
    """

    device_fleet_name: str
    device_fleet_arn: Optional[str] = Unassigned()
    output_config: Optional[EdgeOutputConfig] = Unassigned()
    description: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    role_arn: Optional[str] = Unassigned()
    iot_role_alias: Optional[str] = Unassigned()

    def get_name(self) -> str:
        """
        Look up the value of this resource's name attribute.

        Returns:
            The value of the first instance attribute (in ``vars(self)`` order) whose
            key is "name", "device_fleet_name" or "fleet_name". If none is present an
            error is logged and None is returned.
        """
        instance_attributes = vars(self)
        name_parts = "device_fleet_name".split("_")
        # Every underscore-joined suffix of the resource name is an accepted key:
        # "device_fleet_name", "fleet_name" and "name".
        accepted_keys = {"_".join(name_parts[start:]) for start in range(len(name_parts))}
        accepted_keys.add("name")

        for key, held_value in instance_attributes.items():
            if key in accepted_keys:
                return held_value

        logger.error("Name attribute not found for object device_fleet")
        return None

    def populate_inputs_decorator(create_func):
        """
        Decorate ``create_func`` so its keyword arguments are first routed through
        ``Base.get_updated_kwargs_with_configured_attributes``.

        The helper receives the DeviceFleet config schema, the resource name
        "DeviceFleet" and the caller's keyword arguments; whatever mapping it returns
        is used as the keyword arguments of the wrapped call. Positional arguments
        are forwarded unchanged.
        """

        @functools.wraps(create_func)
        def wrapper(*args, **kwargs):
            # The schema is rebuilt on every call, so the helper always gets a fresh dict.
            device_fleet_schema = {
                "output_config": {
                    "s3_output_location": {"type": "string"},
                    "kms_key_id": {"type": "string"},
                },
                "role_arn": {"type": "string"},
                "iot_role_alias": {"type": "string"},
            }
            resolved_kwargs = Base.get_updated_kwargs_with_configured_attributes(
                device_fleet_schema, "DeviceFleet", **kwargs
            )
            return create_func(*args, **resolved_kwargs)

        return wrapper

    @classmethod
    @populate_inputs_decorator
    @Base.add_validate_call
    def create(
        cls,
        device_fleet_name: str,
        output_config: EdgeOutputConfig,
        role_arn: Optional[str] = Unassigned(),
        description: Optional[str] = Unassigned(),
        tags: Optional[List[Tag]] = Unassigned(),
        enable_iot_role_alias: Optional[bool] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["DeviceFleet"]:
        """
        Create a DeviceFleet resource and return it as described by the service.

        Parameters:
            device_fleet_name: The name of the fleet that the device belongs to.
            output_config: The output configuration for storing sample data collected by the fleet.
            role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT).
            description: A description of the fleet.
            tags: Creates tags for the specified fleet.
            enable_iot_role_alias: Whether to create an Amazon Web Services IoT Role Alias during device fleet creation. The name of the role alias generated will match this pattern: "SageMakerEdge-{DeviceFleetName}". For example, if your device fleet is called "demo-fleet", the name of the role alias will be "SageMakerEdge-demo-fleet".
            session: Boto3 session.
            region: Region name.

        Returns:
            The DeviceFleet resource, obtained by calling ``get`` with the same
            fleet name, session and region once the create call has returned.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
            ResourceInUse: Resource being accessed is in use.
            ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
            ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
            LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
            S3ConfigNotFoundError: Raised when a configuration file is not found in S3
        """
        logger.info("Creating device_fleet resource.")
        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        create_request = Base.populate_chained_attributes(
            resource_name="DeviceFleet",
            operation_input_args={
                "DeviceFleetName": device_fleet_name,
                "RoleArn": role_arn,
                "Description": description,
                "OutputConfig": output_config,
                "Tags": tags,
                "EnableIotRoleAlias": enable_iot_role_alias,
            },
        )
        logger.debug(f"Input request: {create_request}")

        create_request = serialize(create_request)
        logger.debug(f"Serialized input request: {create_request}")

        create_response = sagemaker_client.create_device_fleet(**create_request)
        logger.debug(f"Response: {create_response}")

        return cls.get(device_fleet_name=device_fleet_name, session=session, region=region)

    @classmethod
    @Base.add_validate_call
    def get(
        cls,
        device_fleet_name: str,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional["DeviceFleet"]:
        """
        Describe a device fleet and build a DeviceFleet from the response.

        Parameters:
            device_fleet_name: The name of the fleet.
            session: Boto3 session.
            region: Region name.

        Returns:
            The DeviceFleet resource.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
            ResourceNotFound: Resource being access is not found.
        """
        describe_request = serialize({"DeviceFleetName": device_fleet_name})
        logger.debug(f"Serialized input request: {describe_request}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )
        describe_response = sagemaker_client.describe_device_fleet(**describe_request)
        logger.debug(describe_response)

        # Convert the raw response into constructor keyword arguments.
        constructor_kwargs = transform(describe_response, "DescribeDeviceFleetResponse")
        return cls(**constructor_kwargs)

    @Base.add_validate_call
    def refresh(
        self,
    ) -> Optional["DeviceFleet"]:
        """
        Re-describe this DeviceFleet and update the instance from the response.

        Uses a client obtained from ``Base.get_sagemaker_client()`` called with its
        default arguments.

        Returns:
            This same DeviceFleet instance.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
            ResourceNotFound: Resource being access is not found.
        """
        describe_request = serialize({"DeviceFleetName": self.device_fleet_name})
        logger.debug(f"Serialized input request: {describe_request}")

        sagemaker_client = Base.get_sagemaker_client()
        describe_response = sagemaker_client.describe_device_fleet(**describe_request)

        # Passing ``self`` asks transform() to update this object.
        transform(describe_response, "DescribeDeviceFleetResponse", self)
        return self

    @populate_inputs_decorator
    @Base.add_validate_call
    def update(
        self,
        output_config: EdgeOutputConfig,
        role_arn: Optional[str] = Unassigned(),
        description: Optional[str] = Unassigned(),
        enable_iot_role_alias: Optional[bool] = Unassigned(),
    ) -> Optional["DeviceFleet"]:
        """
        Update this DeviceFleet and refresh the instance afterwards.

        Parameters:
            output_config: The output configuration for storing sampled data.
            role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT).
            description: A description of the fleet.
            enable_iot_role_alias: Whether to create an Amazon Web Services IoT Role Alias during device fleet creation. The name of the role alias generated will match this pattern: "SageMakerEdge-{DeviceFleetName}". For example, if your device fleet is called "demo-fleet", the name of the role alias will be "SageMakerEdge-demo-fleet".

        Returns:
            This same DeviceFleet instance, after ``refresh`` has been called.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
            ResourceInUse: Resource being accessed is in use.
        """
        logger.info("Updating device_fleet resource.")
        sagemaker_client = Base.get_sagemaker_client()

        update_request = {
            "DeviceFleetName": self.device_fleet_name,
            "RoleArn": role_arn,
            "Description": description,
            "OutputConfig": output_config,
            "EnableIotRoleAlias": enable_iot_role_alias,
        }
        logger.debug(f"Input request: {update_request}")

        update_request = serialize(update_request)
        logger.debug(f"Serialized input request: {update_request}")

        # Send the update, then re-describe so local state matches the service.
        update_response = sagemaker_client.update_device_fleet(**update_request)
        logger.debug(f"Response: {update_response}")

        self.refresh()
        return self

    @Base.add_validate_call
    def delete(
        self,
    ) -> None:
        """
        Delete this DeviceFleet.

        An info-level message naming the class and ``get_name()`` is logged once the
        delete call has returned.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
            ResourceInUse: Resource being accessed is in use.
        """
        sagemaker_client = Base.get_sagemaker_client()

        delete_request = serialize({"DeviceFleetName": self.device_fleet_name})
        logger.debug(f"Serialized input request: {delete_request}")

        sagemaker_client.delete_device_fleet(**delete_request)
        logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")

    @classmethod
    @Base.add_validate_call
    def get_all(
        cls,
        creation_time_after: Optional[datetime.datetime] = Unassigned(),
        creation_time_before: Optional[datetime.datetime] = Unassigned(),
        last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
        last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
        name_contains: Optional[str] = Unassigned(),
        sort_by: Optional[str] = Unassigned(),
        sort_order: Optional[str] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> ResourceIterator["DeviceFleet"]:
        """
        List DeviceFleet resources.

        Parameters:
            creation_time_after: Filter fleets where packaging job was created after specified time.
            creation_time_before: Filter fleets where the edge packaging job was created before specified time.
            last_modified_time_after: Select fleets where the job was updated after X
            last_modified_time_before: Select fleets where the job was updated before X
            name_contains: Filter for fleets containing this name in their fleet device name.
            sort_by: The column to sort by.
            sort_order: What direction to sort in.
            session: Boto3 session.
            region: Region name.

        Returns:
            A ResourceIterator configured to call ``list_device_fleets`` with the
            serialized filters and to build ``DeviceFleet`` objects from the
            ``DeviceFleetSummaries`` entries.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
        """
        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        list_filters = serialize(
            {
                "CreationTimeAfter": creation_time_after,
                "CreationTimeBefore": creation_time_before,
                "LastModifiedTimeAfter": last_modified_time_after,
                "LastModifiedTimeBefore": last_modified_time_before,
                "NameContains": name_contains,
                "SortBy": sort_by,
                "SortOrder": sort_order,
            }
        )
        logger.debug(f"Serialized input request: {list_filters}")

        iterator_settings = dict(
            client=sagemaker_client,
            list_method="list_device_fleets",
            summaries_key="DeviceFleetSummaries",
            summary_name="DeviceFleetSummary",
            resource_cls=DeviceFleet,
            list_method_kwargs=list_filters,
        )
        return ResourceIterator(**iterator_settings)

    @Base.add_validate_call
    def deregister_devices(
        self,
        device_names: List[str],
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> None:
        """
        Deregisters the specified devices from this fleet.

        Parameters:
            device_names: The unique IDs of the devices.
            session: Boto3 session.
            region: Region name.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
        """
        deregister_request = serialize(
            {
                "DeviceFleetName": self.device_fleet_name,
                "DeviceNames": device_names,
            }
        )
        logger.debug(f"Serialized input request: {deregister_request}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        logger.debug("Calling deregister_devices API")
        api_response = sagemaker_client.deregister_devices(**deregister_request)
        logger.debug(f"Response: {api_response}")

    @Base.add_validate_call
    def get_report(
        self,
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> Optional[GetDeviceFleetReportResponse]:
        """
        Fetch the device fleet report for this fleet.

        Parameters:
            session: Boto3 session.
            region: Region name.

        Returns:
            GetDeviceFleetReportResponse built from the transformed service response.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
        """
        report_request = serialize({"DeviceFleetName": self.device_fleet_name})
        logger.debug(f"Serialized input request: {report_request}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        logger.debug("Calling get_device_fleet_report API")
        api_response = sagemaker_client.get_device_fleet_report(**report_request)
        logger.debug(f"Response: {api_response}")

        report_kwargs = transform(api_response, "GetDeviceFleetReportResponse")
        return GetDeviceFleetReportResponse(**report_kwargs)

    @Base.add_validate_call
    def register_devices(
        self,
        devices: List[Device],
        tags: Optional[List[Tag]] = Unassigned(),
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> None:
        """
        Register devices with this fleet.

        Parameters:
            devices: A list of devices to register with SageMaker Edge Manager.
            tags: The tags associated with devices.
            session: Boto3 session.
            region: Region name.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
            ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        """
        register_request = serialize(
            {
                "DeviceFleetName": self.device_fleet_name,
                "Devices": devices,
                "Tags": tags,
            }
        )
        logger.debug(f"Serialized input request: {register_request}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        logger.debug("Calling register_devices API")
        api_response = sagemaker_client.register_devices(**register_request)
        logger.debug(f"Response: {api_response}")

    @Base.add_validate_call
    def update_devices(
        self,
        devices: List[Device],
        session: Optional[Session] = None,
        region: Optional[str] = None,
    ) -> None:
        """
        Updates one or more devices in this fleet.

        Parameters:
            devices: List of devices to register with Edge Manager agent.
            session: Boto3 session.
            region: Region name.

        Raises:
            botocore.exceptions.ClientError: Raised for AWS service related errors. The
                message and code are available as ``e.response['Error']['Message']`` and
                ``e.response['Error']['Code']``.
        """
        update_request = serialize(
            {
                "DeviceFleetName": self.device_fleet_name,
                "Devices": devices,
            }
        )
        logger.debug(f"Serialized input request: {update_request}")

        sagemaker_client = Base.get_sagemaker_client(
            session=session, region_name=region, service_name="sagemaker"
        )

        logger.debug("Calling update_devices API")
        api_response = sagemaker_client.update_devices(**update_request)
        logger.debug(f"Response: {api_response}")
[docs]
class Domain(Base):
    """
    Class representing resource Domain

    Only ``domain_id`` is required; every other attribute defaults to
    ``Unassigned()``.

    Attributes:
        domain_arn: The domain's Amazon Resource Name (ARN).
        domain_id: The domain ID.
        domain_name: The domain name.
        home_efs_file_system_id: The ID of the Amazon Elastic File System managed by this Domain.
        single_sign_on_managed_application_instance_id: The IAM Identity Center managed application instance ID.
        single_sign_on_application_arn: The ARN of the application managed by SageMaker in IAM Identity Center. This value is only returned for domains created after October 1, 2023.
        status: The status.
        creation_time: The creation time.
        last_modified_time: The last modified time.
        failure_reason: The failure reason.
        security_group_id_for_domain_boundary: The ID of the security group that authorizes traffic between the RSessionGateway apps and the RStudioServerPro app.
        auth_mode: The domain's authentication mode.
        default_user_settings: Settings which are applied to UserProfiles in this domain if settings are not explicitly specified in a given UserProfile.
        domain_settings: A collection of Domain settings.
        app_network_access_type: Specifies the VPC used for non-EFS traffic. The default value is PublicInternetOnly. PublicInternetOnly - Non-EFS traffic is through a VPC managed by Amazon SageMaker, which allows direct internet access. VpcOnly - All traffic is through the specified VPC and subnets.
        home_efs_file_system_kms_key_id: Use KmsKeyId.
        subnet_ids: The VPC subnets that the domain uses for communication.
        url: The domain's URL.
        vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication.
        kms_key_id: The Amazon Web Services KMS customer managed key used to encrypt the EFS volume attached to the domain.
        app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided.
        tag_propagation: Indicates whether custom tag propagation is supported for the domain.
        default_space_settings: The default settings for shared spaces that users create in the domain.
    """

    # The one required field; get() and refresh() describe the domain by this ID.
    domain_id: str
    domain_arn: Optional[str] = Unassigned()
    domain_name: Optional[str] = Unassigned()
    home_efs_file_system_id: Optional[str] = Unassigned()
    single_sign_on_managed_application_instance_id: Optional[str] = Unassigned()
    single_sign_on_application_arn: Optional[str] = Unassigned()
    status: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    failure_reason: Optional[str] = Unassigned()
    security_group_id_for_domain_boundary: Optional[str] = Unassigned()
    auth_mode: Optional[str] = Unassigned()
    default_user_settings: Optional[UserSettings] = Unassigned()
    domain_settings: Optional[DomainSettings] = Unassigned()
    app_network_access_type: Optional[str] = Unassigned()
    home_efs_file_system_kms_key_id: Optional[str] = Unassigned()
    subnet_ids: Optional[List[str]] = Unassigned()
    url: Optional[str] = Unassigned()
    vpc_id: Optional[str] = Unassigned()
    kms_key_id: Optional[str] = Unassigned()
    app_security_group_management: Optional[str] = Unassigned()
    tag_propagation: Optional[str] = Unassigned()
    default_space_settings: Optional[DefaultSpaceSettings] = Unassigned()
def get_name(self) -> str:
    """
    Look up the value of this resource's name attribute.

    Returns:
        The value of the first instance attribute (in ``vars(self)`` order) whose
        key is "name" or "domain_name". If neither is present an error is logged
        and None is returned.
    """
    instance_attributes = vars(self)
    name_parts = "domain_name".split("_")
    # Every underscore-joined suffix of the resource name is an accepted key:
    # "domain_name" and "name".
    accepted_keys = {"_".join(name_parts[start:]) for start in range(len(name_parts))}
    accepted_keys.add("name")

    for key, held_value in instance_attributes.items():
        if key in accepted_keys:
            return held_value

    logger.error("Name attribute not found for object domain")
    return None
def populate_inputs_decorator(create_func):
    """
    Decorate ``create_func`` so its keyword arguments are first routed through
    ``Base.get_updated_kwargs_with_configured_attributes``.

    The helper receives the Domain config schema, the resource name "Domain" and
    the caller's keyword arguments; whatever mapping it returns is used as the
    keyword arguments of the wrapped call. Positional arguments are forwarded
    unchanged.
    """

    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        # The schema is rebuilt on every call, so the helper always gets a fresh dict.
        domain_schema = {
            "security_group_id_for_domain_boundary": {"type": "string"},
            "default_user_settings": {
                "execution_role": {"type": "string"},
                "security_groups": {"type": "array", "items": {"type": "string"}},
                "sharing_settings": {
                    "s3_output_path": {"type": "string"},
                    "s3_kms_key_id": {"type": "string"},
                },
                "canvas_app_settings": {
                    "time_series_forecasting_settings": {
                        "amazon_forecast_role_arn": {"type": "string"}
                    },
                    "model_register_settings": {
                        "cross_account_model_register_role_arn": {"type": "string"}
                    },
                    "workspace_settings": {
                        "s3_artifact_path": {"type": "string"},
                        "s3_kms_key_id": {"type": "string"},
                    },
                    "generative_ai_settings": {"amazon_bedrock_role_arn": {"type": "string"}},
                    "emr_serverless_settings": {"execution_role_arn": {"type": "string"}},
                },
                "jupyter_lab_app_settings": {
                    "emr_settings": {
                        "assumable_role_arns": {"type": "array", "items": {"type": "string"}},
                        "execution_role_arns": {"type": "array", "items": {"type": "string"}},
                    }
                },
            },
            "domain_settings": {
                "security_group_ids": {"type": "array", "items": {"type": "string"}},
                "r_studio_server_pro_domain_settings": {
                    "domain_execution_role_arn": {"type": "string"}
                },
                "execution_role_identity_config": {"type": "string"},
            },
            "home_efs_file_system_kms_key_id": {"type": "string"},
            "subnet_ids": {"type": "array", "items": {"type": "string"}},
            "kms_key_id": {"type": "string"},
            "app_security_group_management": {"type": "string"},
            "default_space_settings": {
                "execution_role": {"type": "string"},
                "security_groups": {"type": "array", "items": {"type": "string"}},
                "jupyter_lab_app_settings": {
                    "emr_settings": {
                        "assumable_role_arns": {"type": "array", "items": {"type": "string"}},
                        "execution_role_arns": {"type": "array", "items": {"type": "string"}},
                    }
                },
            },
        }
        resolved_kwargs = Base.get_updated_kwargs_with_configured_attributes(
            domain_schema, "Domain", **kwargs
        )
        return create_func(*args, **resolved_kwargs)

    return wrapper
[docs]
@classmethod
@populate_inputs_decorator
@Base.add_validate_call
def create(
    cls,
    domain_name: str,
    auth_mode: str,
    default_user_settings: UserSettings,
    subnet_ids: List[str],
    vpc_id: str,
    domain_settings: Optional[DomainSettings] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    app_network_access_type: Optional[str] = Unassigned(),
    home_efs_file_system_kms_key_id: Optional[str] = Unassigned(),
    kms_key_id: Optional[str] = Unassigned(),
    app_security_group_management: Optional[str] = Unassigned(),
    tag_propagation: Optional[str] = Unassigned(),
    default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Domain"]:
    """
    Create a Domain resource and return it as described by the service.

    Parameters:
        domain_name: A name for the domain.
        auth_mode: The mode of authentication that members use to access the domain.
        default_user_settings: The default settings to use to create a user profile when UserSettings isn't specified in the call to the CreateUserProfile API. SecurityGroups is aggregated when specified in both calls. For all other settings in UserSettings, the values specified in CreateUserProfile take precedence over those specified in CreateDomain.
        subnet_ids: The VPC subnets that the domain uses for communication.
        vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication.
        domain_settings: A collection of Domain settings.
        tags: Tags to associated with the Domain. Each tag consists of a key and an optional value. Tag keys must be unique per resource. Tags are searchable using the Search API. Tags that you specify for the Domain are also added to all Apps that the Domain launches.
        app_network_access_type: Specifies the VPC used for non-EFS traffic. The default value is PublicInternetOnly. PublicInternetOnly - Non-EFS traffic is through a VPC managed by Amazon SageMaker, which allows direct internet access. VpcOnly - All traffic is through the specified VPC and subnets.
        home_efs_file_system_kms_key_id: Use KmsKeyId.
        kms_key_id: SageMaker uses Amazon Web Services KMS to encrypt EFS and EBS volumes attached to the domain with an Amazon Web Services managed key by default. For more control, specify a customer managed key.
        app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided. If setting up the domain for use with RStudio, this value must be set to Service.
        tag_propagation: Indicates whether custom tag propagation is supported for the domain. Defaults to DISABLED.
        default_space_settings: The default settings for shared spaces that users create in the domain.
        session: Boto3 session.
        region: Region name.

    Returns:
        The Domain resource, obtained by calling ``get`` with the ``DomainId``
        from the create response and the same session and region.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available as ``e.response['Error']['Message']`` and
            ``e.response['Error']['Code']``.
        ResourceInUse: Resource being accessed is in use.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating domain resource.")
    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    create_request = Base.populate_chained_attributes(
        resource_name="Domain",
        operation_input_args={
            "DomainName": domain_name,
            "AuthMode": auth_mode,
            "DefaultUserSettings": default_user_settings,
            "DomainSettings": domain_settings,
            "SubnetIds": subnet_ids,
            "VpcId": vpc_id,
            "Tags": tags,
            "AppNetworkAccessType": app_network_access_type,
            "HomeEfsFileSystemKmsKeyId": home_efs_file_system_kms_key_id,
            "KmsKeyId": kms_key_id,
            "AppSecurityGroupManagement": app_security_group_management,
            "TagPropagation": tag_propagation,
            "DefaultSpaceSettings": default_space_settings,
        },
    )
    logger.debug(f"Input request: {create_request}")

    create_request = serialize(create_request)
    logger.debug(f"Serialized input request: {create_request}")

    create_response = sagemaker_client.create_domain(**create_request)
    logger.debug(f"Response: {create_response}")

    # The service assigns the ID, so it is read back from the create response.
    return cls.get(domain_id=create_response["DomainId"], session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    domain_id: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Domain"]:
    """
    Describe a domain and build a Domain from the response.

    Parameters:
        domain_id: The domain ID.
        session: Boto3 session.
        region: Region name.

    Returns:
        The Domain resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available as ``e.response['Error']['Message']`` and
            ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    describe_request = serialize({"DomainId": domain_id})
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    describe_response = sagemaker_client.describe_domain(**describe_request)
    logger.debug(describe_response)

    # Convert the raw response into constructor keyword arguments.
    constructor_kwargs = transform(describe_response, "DescribeDomainResponse")
    return cls(**constructor_kwargs)
[docs]
@Base.add_validate_call
def refresh(
    self,
) -> Optional["Domain"]:
    """
    Re-describe this Domain and update the instance from the response.

    Uses a client obtained from ``Base.get_sagemaker_client()`` called with its
    default arguments.

    Returns:
        This same Domain instance.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors. The
            message and code are available as ``e.response['Error']['Message']`` and
            ``e.response['Error']['Code']``.
        ResourceNotFound: Resource being access is not found.
    """
    describe_request = serialize({"DomainId": self.domain_id})
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker_client = Base.get_sagemaker_client()
    describe_response = sagemaker_client.describe_domain(**describe_request)

    # Passing ``self`` asks transform() to update this object.
    transform(describe_response, "DescribeDomainResponse", self)
    return self
[docs]
@populate_inputs_decorator
@Base.add_validate_call
def update(
    self,
    default_user_settings: Optional[UserSettings] = Unassigned(),
    domain_settings_for_update: Optional[DomainSettingsForUpdate] = Unassigned(),
    app_security_group_management: Optional[str] = Unassigned(),
    default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(),
    subnet_ids: Optional[List[str]] = Unassigned(),
    app_network_access_type: Optional[str] = Unassigned(),
    tag_propagation: Optional[str] = Unassigned(),
) -> Optional["Domain"]:
    """
    Update this Domain resource and refresh it from the service.

    Parameters:
        default_user_settings: Sent as `DefaultUserSettings` in the UpdateDomain request.
        domain_settings_for_update: A collection of DomainSettings configuration values to update.
        app_security_group_management: Sent as `AppSecurityGroupManagement` in the request.
        default_space_settings: Sent as `DefaultSpaceSettings` in the request.
        subnet_ids: Sent as `SubnetIds` in the request.
        app_network_access_type: Sent as `AppNetworkAccessType` in the request.
        tag_propagation: Sent as `TagPropagation` in the request.

    Returns:
        This Domain instance, refreshed after the update call.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceInUse: Resource being accessed is in use.
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ResourceNotFound: Resource being accessed is not found.
    """
    logger.info("Updating domain resource.")
    sagemaker = Base.get_sagemaker_client()

    update_request = {
        "DomainId": self.domain_id,
        "DefaultUserSettings": default_user_settings,
        "DomainSettingsForUpdate": domain_settings_for_update,
        "AppSecurityGroupManagement": app_security_group_management,
        "DefaultSpaceSettings": default_space_settings,
        "SubnetIds": subnet_ids,
        "AppNetworkAccessType": app_network_access_type,
        "TagPropagation": tag_propagation,
    }
    logger.debug(f"Input request: {update_request}")

    serialized_request = serialize(update_request)
    logger.debug(f"Serialized input request: {serialized_request}")

    # Apply the update, then re-describe so this instance reflects the service state.
    update_response = sagemaker.update_domain(**serialized_request)
    logger.debug(f"Response: {update_response}")

    self.refresh()
    return self
[docs]
@Base.add_validate_call
def delete(
    self,
    retention_policy: Optional[RetentionPolicy] = Unassigned(),
) -> None:
    """
    Delete this Domain resource.

    Parameters:
        retention_policy: Sent as `RetentionPolicy` in the DeleteDomain request.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceInUse: Resource being accessed is in use.
        ResourceNotFound: Resource being accessed is not found.
    """
    sagemaker = Base.get_sagemaker_client()

    delete_request = serialize(
        {
            "DomainId": self.domain_id,
            "RetentionPolicy": retention_policy,
        }
    )
    logger.debug(f"Serialized input request: {delete_request}")

    sagemaker.delete_domain(**delete_request)

    # Logged only after the delete call has been accepted by the service.
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def wait_for_status(
    self,
    target_status: Literal[
        "Deleting",
        "Failed",
        "InService",
        "Pending",
        "Updating",
        "Update_Failed",
        "Delete_Failed",
    ],
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a Domain resource to reach certain status.

    The resource is refreshed on every iteration and its current status is
    shown in a transient live panel while waiting.

    Parameters:
        target_status: The status to wait for.
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
            `None` waits indefinitely.

    Raises:
        TimeoutExceededError: If the resource does not reach the target status before the timeout.
        FailedStatusError: If the resource reaches a failed state that is not the target status.
        WaiterError: Raised when an error occurs while waiting.
    """
    start_time = time.time()

    progress = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    progress.add_task(f"Waiting for Domain to reach [bold]{target_status} status...")
    status = Status("Current status:")

    with Live(
        Panel(
            Group(progress, status),
            title="Wait Log Panel",
            border_style=Style(color=Color.BLUE.value),
        ),
        transient=True,
    ):
        while True:
            self.refresh()
            current_status = self.status
            status.update(f"Current status: [bold]{current_status}")

            # The target check comes first so that waiting for a "*Failed"
            # status returns normally instead of raising below.
            if target_status == current_status:
                logger.info(f"Final Resource Status: [bold]{current_status}")
                return

            if "failed" in current_status.lower():
                raise FailedStatusError(
                    resource_type="Domain", status=current_status, reason=self.failure_reason
                )

            # Keyword spelled `resource_type`, matching the other waiter
            # exceptions raised in this class (was misspelled `resouce_type`).
            if timeout is not None and time.time() - start_time >= timeout:
                raise TimeoutExceededError(resource_type="Domain", status=current_status)
            time.sleep(poll)
[docs]
@Base.add_validate_call
def wait_for_delete(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Wait for a Domain resource to be deleted.

    The resource is refreshed on every iteration; the wait ends successfully
    when the refresh fails with a ClientError whose code contains
    "ResourceNotFound" or "ValidationException".

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
            `None` waits indefinitely.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors
            other than the not-found codes described above.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        TimeoutExceededError: If the resource is not deleted before the timeout.
        DeleteFailedStatusError: If the resource reaches a delete-failed state.
        WaiterError: Raised when an error occurs while waiting.
    """
    start_time = time.time()

    progress = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    progress.add_task("Waiting for Domain to be deleted...")
    status = Status("Current status:")

    with Live(
        Panel(
            Group(progress, status),
            title="Wait Log Panel",
            border_style=Style(color=Color.BLUE.value),
        )
    ):
        while True:
            try:
                self.refresh()
                current_status = self.status
                status.update(f"Current status: [bold]{current_status}")

                if (
                    "delete_failed" in current_status.lower()
                    or "deletefailed" in current_status.lower()
                ):
                    raise DeleteFailedStatusError(
                        resource_type="Domain", reason=self.failure_reason
                    )

                # Keyword spelled `resource_type`, matching DeleteFailedStatusError
                # above (was misspelled `resouce_type`).
                if timeout is not None and time.time() - start_time >= timeout:
                    raise TimeoutExceededError(resource_type="Domain", status=current_status)
            except botocore.exceptions.ClientError as e:
                error_code = e.response["Error"]["Code"]

                # A failing describe with one of these codes is treated as
                # confirmation that the resource no longer exists.
                if "ResourceNotFound" in error_code or "ValidationException" in error_code:
                    logger.info("Resource was not found. It may have been deleted.")
                    return
                # Bare raise re-raises the active exception with its traceback intact.
                raise
            time.sleep(poll)
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["Domain"]:
    """
    List Domain resources.

    Parameters:
        session: Boto3 session.
        region: Region name.

    Returns:
        A ResourceIterator over Domain resources, backed by the `list_domains` call.
    """
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Settings that tell the iterator which list call to page through and
    # how to turn each returned summary into a Domain.
    iterator_settings = {
        "client": sagemaker,
        "list_method": "list_domains",
        "summaries_key": "Domains",
        "summary_name": "DomainDetails",
        "resource_cls": Domain,
    }
    return ResourceIterator(**iterator_settings)
[docs]
class EdgeDeploymentPlan(Base):
    """
    Class representing resource EdgeDeploymentPlan

    Only `edge_deployment_plan_name` is required; every other attribute
    defaults to `Unassigned()`.

    Attributes:
        edge_deployment_plan_arn: The ARN of edge deployment plan.
        edge_deployment_plan_name: The name of the edge deployment plan.
        model_configs: List of models associated with the edge deployment plan.
        device_fleet_name: The device fleet used for this edge deployment plan.
        stages: List of stages in the edge deployment plan.
        edge_deployment_success: The number of edge devices with the successful deployment.
        edge_deployment_pending: The number of edge devices yet to pick up deployment, or in progress.
        edge_deployment_failed: The number of edge devices that failed the deployment.
        next_token: Token to use when calling the next set of stages in the edge deployment plan.
        creation_time: The time when the edge deployment plan was created.
        last_modified_time: The time when the edge deployment plan was last updated.
    """

    # Required identifier; used as the request key by the instance methods.
    edge_deployment_plan_name: str
    # Optional attributes; `Unassigned()` marks a value that has not been set.
    edge_deployment_plan_arn: Optional[str] = Unassigned()
    model_configs: Optional[List[EdgeDeploymentModelConfig]] = Unassigned()
    device_fleet_name: Optional[str] = Unassigned()
    edge_deployment_success: Optional[int] = Unassigned()
    edge_deployment_pending: Optional[int] = Unassigned()
    edge_deployment_failed: Optional[int] = Unassigned()
    stages: Optional[List[DeploymentStageStatusSummary]] = Unassigned()
    next_token: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
def get_name(self) -> str:
    """
    Return the value of the attribute that holds this resource's name.

    The instance attributes are scanned in order and the first one whose
    key is an accepted name key is returned.

    Returns:
        The value of the first matching attribute, or None (after logging
        an error) when no attribute matches.
    """
    name_parts = "edge_deployment_plan_name".split("_")
    # Every "_"-joined suffix of the name attribute is accepted:
    # "edge_deployment_plan_name", "deployment_plan_name", "plan_name", "name".
    accepted_keys = ["_".join(name_parts[start:]) for start in range(len(name_parts))]

    for key, candidate in vars(self).items():
        if key == "name" or key in accepted_keys:
            return candidate

    logger.error("Name attribute not found for object edge_deployment_plan")
    return None
[docs]
@classmethod
@Base.add_validate_call
def create(
    cls,
    edge_deployment_plan_name: str,
    model_configs: List[EdgeDeploymentModelConfig],
    device_fleet_name: Union[str, object],
    stages: Optional[List[DeploymentStage]] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["EdgeDeploymentPlan"]:
    """
    Create an EdgeDeploymentPlan resource and return it as described by the service.

    Parameters:
        edge_deployment_plan_name: The name of the edge deployment plan.
        model_configs: List of models associated with the edge deployment plan.
        device_fleet_name: The device fleet used for this edge deployment plan.
        stages: List of stages of the edge deployment plan. The number of stages is limited to 10 per deployment.
        tags: List of tags with which to tag the edge deployment plan.
        session: Boto3 session.
        region: Region name.

    Returns:
        The EdgeDeploymentPlan resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating edge_deployment_plan resource.")
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    create_request = Base.populate_chained_attributes(
        resource_name="EdgeDeploymentPlan",
        operation_input_args={
            "EdgeDeploymentPlanName": edge_deployment_plan_name,
            "ModelConfigs": model_configs,
            "DeviceFleetName": device_fleet_name,
            "Stages": stages,
            "Tags": tags,
        },
    )
    logger.debug(f"Input request: {create_request}")

    serialized_request = serialize(create_request)
    logger.debug(f"Serialized input request: {serialized_request}")

    create_response = sagemaker.create_edge_deployment_plan(**serialized_request)
    logger.debug(f"Response: {create_response}")

    # Describe the plan by name so the returned object comes from the service.
    return cls.get(
        edge_deployment_plan_name=edge_deployment_plan_name, session=session, region=region
    )
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    edge_deployment_plan_name: str,
    next_token: Optional[str] = Unassigned(),
    max_results: Optional[int] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["EdgeDeploymentPlan"]:
    """
    Describe an EdgeDeploymentPlan and build a resource object from the response.

    Parameters:
        edge_deployment_plan_name: The name of the deployment plan to describe.
        next_token: If the edge deployment plan has enough stages to require tokening, then this is the response from the last list of stages returned.
        max_results: The maximum number of results to select (50 by default).
        session: Boto3 session.
        region: Region name.

    Returns:
        The EdgeDeploymentPlan resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: Resource being accessed is not found.
    """
    describe_request = serialize(
        {
            "EdgeDeploymentPlanName": edge_deployment_plan_name,
            "NextToken": next_token,
            "MaxResults": max_results,
        }
    )
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    describe_response = sagemaker.describe_edge_deployment_plan(**describe_request)
    logger.debug(describe_response)

    # Deserialize the response into constructor keyword arguments.
    return cls(**transform(describe_response, "DescribeEdgeDeploymentPlanResponse"))
[docs]
@Base.add_validate_call
def refresh(
    self,
    max_results: Optional[int] = Unassigned(),
) -> Optional["EdgeDeploymentPlan"]:
    """
    Reload this EdgeDeploymentPlan from the service and update it in place.

    The request carries this instance's current `next_token`.

    Parameters:
        max_results: Sent as `MaxResults` in the describe request.

    Returns:
        This EdgeDeploymentPlan instance, updated from the describe response.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: Resource being accessed is not found.
    """
    describe_request = serialize(
        {
            "EdgeDeploymentPlanName": self.edge_deployment_plan_name,
            "NextToken": self.next_token,
            "MaxResults": max_results,
        }
    )
    logger.debug(f"Serialized input request: {describe_request}")

    describe_response = Base.get_sagemaker_client().describe_edge_deployment_plan(
        **describe_request
    )

    # Passing `self` makes transform() deserialize the response onto this instance.
    transform(describe_response, "DescribeEdgeDeploymentPlanResponse", self)
    return self
[docs]
@Base.add_validate_call
def delete(
    self,
) -> None:
    """
    Delete this EdgeDeploymentPlan resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceInUse: Resource being accessed is in use.
    """
    sagemaker = Base.get_sagemaker_client()

    delete_request = serialize({"EdgeDeploymentPlanName": self.edge_deployment_plan_name})
    logger.debug(f"Serialized input request: {delete_request}")

    sagemaker.delete_edge_deployment_plan(**delete_request)

    # Logged only after the delete call has been accepted by the service.
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
    name_contains: Optional[str] = Unassigned(),
    device_fleet_name_contains: Optional[str] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["EdgeDeploymentPlan"]:
    """
    List EdgeDeploymentPlan resources, optionally filtered and sorted.

    Parameters:
        creation_time_after: Selects edge deployment plans created after this time.
        creation_time_before: Selects edge deployment plans created before this time.
        last_modified_time_after: Selects edge deployment plans that were last updated after this time.
        last_modified_time_before: Selects edge deployment plans that were last updated before this time.
        name_contains: Selects edge deployment plans with names containing this name.
        device_fleet_name_contains: Selects edge deployment plans with a device fleet name containing this name.
        sort_by: The column by which to sort the edge deployment plans. Can be one of NAME, DEVICEFLEETNAME, CREATIONTIME, LASTMODIFIEDTIME.
        sort_order: The direction of the sorting (ascending or descending).
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed EdgeDeploymentPlan resources.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    list_filters = serialize(
        {
            "CreationTimeAfter": creation_time_after,
            "CreationTimeBefore": creation_time_before,
            "LastModifiedTimeAfter": last_modified_time_after,
            "LastModifiedTimeBefore": last_modified_time_before,
            "NameContains": name_contains,
            "DeviceFleetNameContains": device_fleet_name_contains,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
    )
    logger.debug(f"Serialized input request: {list_filters}")

    # The serialized filters are handed to the iterator as the list call's kwargs.
    return ResourceIterator(
        client=sagemaker,
        list_method="list_edge_deployment_plans",
        summaries_key="EdgeDeploymentPlanSummaries",
        summary_name="EdgeDeploymentPlanSummary",
        resource_cls=EdgeDeploymentPlan,
        list_method_kwargs=list_filters,
    )
[docs]
@Base.add_validate_call
def create_stage(
    self,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> None:
    """
    Creates a new stage in an existing edge deployment plan.

    The request is built from this instance's `edge_deployment_plan_name`
    and `stages` attributes.

    Parameters:
        session: Boto3 session.
        region: Region name.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
    """
    stage_request = serialize(
        {
            "EdgeDeploymentPlanName": self.edge_deployment_plan_name,
            "Stages": self.stages,
        }
    )
    logger.debug(f"Serialized input request: {stage_request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling create_edge_deployment_stage API")
    api_response = sagemaker.create_edge_deployment_stage(**stage_request)
    logger.debug(f"Response: {api_response}")
[docs]
@Base.add_validate_call
def delete_stage(
    self,
    stage_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> None:
    """
    Delete a stage in an edge deployment plan if (and only if) the stage is inactive.

    Parameters:
        stage_name: The name of the stage.
        session: Boto3 session.
        region: Region name.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceInUse: Resource being accessed is in use.
    """
    stage_request = serialize(
        {
            "EdgeDeploymentPlanName": self.edge_deployment_plan_name,
            "StageName": stage_name,
        }
    )
    logger.debug(f"Serialized input request: {stage_request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling delete_edge_deployment_stage API")
    api_response = sagemaker.delete_edge_deployment_stage(**stage_request)
    logger.debug(f"Response: {api_response}")
[docs]
@Base.add_validate_call
def start_stage(
    self,
    stage_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> None:
    """
    Starts a stage in an edge deployment plan.

    Parameters:
        stage_name: The name of the stage to start.
        session: Boto3 session.
        region: Region name.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    stage_request = serialize(
        {
            "EdgeDeploymentPlanName": self.edge_deployment_plan_name,
            "StageName": stage_name,
        }
    )
    logger.debug(f"Serialized input request: {stage_request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling start_edge_deployment_stage API")
    api_response = sagemaker.start_edge_deployment_stage(**stage_request)
    logger.debug(f"Response: {api_response}")
[docs]
@Base.add_validate_call
def stop_stage(
    self,
    stage_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> None:
    """
    Stops a stage in an edge deployment plan.

    Parameters:
        stage_name: The name of the stage to stop.
        session: Boto3 session.
        region: Region name.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    stage_request = serialize(
        {
            "EdgeDeploymentPlanName": self.edge_deployment_plan_name,
            "StageName": stage_name,
        }
    )
    logger.debug(f"Serialized input request: {stage_request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling stop_edge_deployment_stage API")
    api_response = sagemaker.stop_edge_deployment_stage(**stage_request)
    logger.debug(f"Response: {api_response}")
[docs]
@Base.add_validate_call
def get_all_stage_devices(
    self,
    stage_name: str,
    exclude_devices_deployed_in_other_stage: Optional[bool] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator[DeviceDeploymentSummary]:
    """
    Lists devices allocated to the stage, containing detailed device information and deployment status.

    Parameters:
        stage_name: The name of the stage in the deployment.
        exclude_devices_deployed_in_other_stage: Toggle for excluding devices deployed in other stages.
        session: Boto3 session.
        region: Region name.

    Returns:
        Iterator for listed DeviceDeploymentSummary.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    list_filters = serialize(
        {
            "EdgeDeploymentPlanName": self.edge_deployment_plan_name,
            "ExcludeDevicesDeployedInOtherStage": exclude_devices_deployed_in_other_stage,
            "StageName": stage_name,
        }
    )
    logger.debug(f"Serialized input request: {list_filters}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # The serialized filters are handed to the iterator as the list call's kwargs.
    return ResourceIterator(
        client=sagemaker,
        list_method="list_stage_devices",
        summaries_key="DeviceDeploymentSummaries",
        summary_name="DeviceDeploymentSummary",
        resource_cls=DeviceDeploymentSummary,
        list_method_kwargs=list_filters,
    )
[docs]
class EdgePackagingJob(Base):
    """
    Class representing resource EdgePackagingJob

    Only `edge_packaging_job_name` is required; every other attribute
    defaults to `Unassigned()`.

    Attributes:
        edge_packaging_job_arn: The Amazon Resource Name (ARN) of the edge packaging job.
        edge_packaging_job_name: The name of the edge packaging job.
        edge_packaging_job_status: The current status of the packaging job.
        compilation_job_name: The name of the SageMaker Neo compilation job that is used to locate model artifacts that are being packaged.
        model_name: The name of the model.
        model_version: The version of the model.
        role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to download and upload the model, and to contact Neo.
        output_config: The output configuration for the edge packaging job.
        resource_key: The Amazon Web Services KMS key to use when encrypting the EBS volume the job runs on.
        edge_packaging_job_status_message: Returns a message describing the job status and error messages.
        creation_time: The timestamp of when the packaging job was created.
        last_modified_time: The timestamp of when the job was last updated.
        model_artifact: The Amazon Simple Storage (S3) URI where model artifacts are stored.
        model_signature: The signature document of files in the model artifact.
        preset_deployment_output: The output of a SageMaker Edge Manager deployable resource.
    """

    # Required identifier; used as the request key by the instance methods.
    edge_packaging_job_name: str
    # Optional attributes; `Unassigned()` marks a value that has not been set.
    edge_packaging_job_arn: Optional[str] = Unassigned()
    compilation_job_name: Optional[str] = Unassigned()
    model_name: Optional[str] = Unassigned()
    model_version: Optional[str] = Unassigned()
    role_arn: Optional[str] = Unassigned()
    output_config: Optional[EdgeOutputConfig] = Unassigned()
    resource_key: Optional[str] = Unassigned()
    edge_packaging_job_status: Optional[str] = Unassigned()
    edge_packaging_job_status_message: Optional[str] = Unassigned()
    creation_time: Optional[datetime.datetime] = Unassigned()
    last_modified_time: Optional[datetime.datetime] = Unassigned()
    model_artifact: Optional[str] = Unassigned()
    model_signature: Optional[str] = Unassigned()
    preset_deployment_output: Optional[EdgePresetDeploymentOutput] = Unassigned()
def get_name(self) -> str:
    """
    Return the value of the attribute that holds this resource's name.

    The instance attributes are scanned in order and the first one whose
    key is an accepted name key is returned.

    Returns:
        The value of the first matching attribute, or None (after logging
        an error) when no attribute matches.
    """
    name_parts = "edge_packaging_job_name".split("_")
    # Every "_"-joined suffix of the name attribute is accepted:
    # "edge_packaging_job_name", "packaging_job_name", "job_name", "name".
    accepted_keys = ["_".join(name_parts[start:]) for start in range(len(name_parts))]

    for key, candidate in vars(self).items():
        if key == "name" or key in accepted_keys:
            return candidate

    logger.error("Name attribute not found for object edge_packaging_job")
    return None
def populate_inputs_decorator(create_func):
    """
    Wrap `create_func` so its keyword arguments are first passed through
    `Base.get_updated_kwargs_with_configured_attributes` for EdgePackagingJob.

    Positional arguments are forwarded unchanged; only keyword arguments
    are handed to the helper, together with the schema below.
    """

    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        # Schema of the attributes the helper is asked to look at.
        # Built per call, so each invocation passes a fresh dict.
        output_config_schema = {
            "s3_output_location": {"type": "string"},
            "kms_key_id": {"type": "string"},
        }
        config_schema_for_resource = {
            "role_arn": {"type": "string"},
            "output_config": output_config_schema,
        }
        resolved_kwargs = Base.get_updated_kwargs_with_configured_attributes(
            config_schema_for_resource, "EdgePackagingJob", **kwargs
        )
        return create_func(*args, **resolved_kwargs)

    return wrapper
[docs]
@classmethod
@populate_inputs_decorator
@Base.add_validate_call
def create(
    cls,
    edge_packaging_job_name: str,
    compilation_job_name: Union[str, object],
    model_name: Union[str, object],
    model_version: str,
    role_arn: str,
    output_config: EdgeOutputConfig,
    resource_key: Optional[str] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["EdgePackagingJob"]:
    """
    Create an EdgePackagingJob resource and return it as described by the service.

    Parameters:
        edge_packaging_job_name: The name of the edge packaging job.
        compilation_job_name: The name of the SageMaker Neo compilation job that will be used to locate model artifacts for packaging.
        model_name: The name of the model.
        model_version: The version of the model.
        role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to download and upload the model, and to contact SageMaker Neo.
        output_config: Provides information about the output location for the packaged model.
        resource_key: The Amazon Web Services KMS key to use when encrypting the EBS volume the edge packaging job runs on.
        tags: Creates tags for the packaging job.
        session: Boto3 session.
        region: Region name.

    Returns:
        The EdgePackagingJob resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating edge_packaging_job resource.")
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    create_request = Base.populate_chained_attributes(
        resource_name="EdgePackagingJob",
        operation_input_args={
            "EdgePackagingJobName": edge_packaging_job_name,
            "CompilationJobName": compilation_job_name,
            "ModelName": model_name,
            "ModelVersion": model_version,
            "RoleArn": role_arn,
            "OutputConfig": output_config,
            "ResourceKey": resource_key,
            "Tags": tags,
        },
    )
    logger.debug(f"Input request: {create_request}")

    serialized_request = serialize(create_request)
    logger.debug(f"Serialized input request: {serialized_request}")

    create_response = sagemaker.create_edge_packaging_job(**serialized_request)
    logger.debug(f"Response: {create_response}")

    # Describe the job by name so the returned object comes from the service.
    return cls.get(
        edge_packaging_job_name=edge_packaging_job_name, session=session, region=region
    )
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    edge_packaging_job_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["EdgePackagingJob"]:
    """
    Describe an EdgePackagingJob and build a resource object from the response.

    Parameters:
        edge_packaging_job_name: The name of the edge packaging job.
        session: Boto3 session.
        region: Region name.

    Returns:
        The EdgePackagingJob resource.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: Resource being accessed is not found.
    """
    describe_request = serialize({"EdgePackagingJobName": edge_packaging_job_name})
    logger.debug(f"Serialized input request: {describe_request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    describe_response = sagemaker.describe_edge_packaging_job(**describe_request)
    logger.debug(describe_response)

    # Deserialize the response into constructor keyword arguments.
    return cls(**transform(describe_response, "DescribeEdgePackagingJobResponse"))
[docs]
@Base.add_validate_call
def refresh(self) -> Optional["EdgePackagingJob"]:
    """
    Re-describe this EdgePackagingJob and update this instance in place.

    The default client from Base.get_sagemaker_client() is used; no
    session or region can be supplied here.

    Returns:
        This same EdgePackagingJob instance, after the update.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceNotFound: The resource being accessed is not found.
    """
    request = serialize({"EdgePackagingJobName": self.edge_packaging_job_name})
    logger.debug(f"Serialized input request: {request}")

    described = Base.get_sagemaker_client().describe_edge_packaging_job(**request)

    # Passing self makes transform write the deserialized response onto this object.
    transform(described, "DescribeEdgePackagingJobResponse", self)
    return self
[docs]
@Base.add_validate_call
def stop(self) -> None:
    """
    Request that this EdgePackagingJob be stopped.

    Sends stop_edge_packaging_job for this job's name and logs the request.
    The method does not wait for the job to actually stop.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    # Obtain the client through the shared Base helper, exactly as refresh()
    # does, instead of reading a `.client` attribute off SageMakerClient
    # directly. The helper resolves the client via
    # SageMakerClient(...).get_client(service_name="sagemaker").
    client = Base.get_sagemaker_client()

    operation_input_args = {
        "EdgePackagingJobName": self.edge_packaging_job_name,
    }
    # serialize the input request
    operation_input_args = serialize(operation_input_args)
    logger.debug(f"Serialized input request: {operation_input_args}")

    client.stop_edge_packaging_job(**operation_input_args)
    logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def wait(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Block until this EdgePackagingJob reaches a terminal status.

    The job is refreshed on every iteration and its current status is shown
    in a transient Rich panel. COMPLETED, FAILED and STOPPED end the wait.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
            When None, no timeout is applied.

    Raises:
        TimeoutExceededError: If the resource does not reach a terminal state before the timeout.
        FailedStatusError: If the resource reaches a terminal status containing "failed".
    """
    finished_states = ["COMPLETED", "FAILED", "STOPPED"]
    started_at = time.time()

    spinner = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    spinner.add_task("Waiting for EdgePackagingJob...")
    status_line = Status("Current status:")
    wait_panel = Panel(
        Group(spinner, status_line),
        title="Wait Log Panel",
        border_style=Style(color=Color.BLUE.value),
    )

    with Live(wait_panel, transient=True):
        while True:
            self.refresh()
            current_status = self.edge_packaging_job_status
            status_line.update(f"Current status: [bold]{current_status}")

            if current_status in finished_states:
                logger.info(f"Final Resource Status: [bold]{current_status}")
                if "failed" not in current_status.lower():
                    return
                raise FailedStatusError(
                    resource_type="EdgePackagingJob",
                    status=current_status,
                    reason=self.edge_packaging_job_status_message,
                )

            # The timeout is only evaluated after a non-terminal poll.
            timed_out = timeout is not None and time.time() - started_at >= timeout
            if timed_out:
                # NOTE(review): the keyword is spelled `resouce_type` here; confirm it
                # matches the TimeoutExceededError signature before renaming.
                raise TimeoutExceededError(
                    resouce_type="EdgePackagingJob", status=current_status
                )
            time.sleep(poll)
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
    name_contains: Optional[str] = Unassigned(),
    model_name_contains: Optional[str] = Unassigned(),
    status_equals: Optional[str] = Unassigned(),
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["EdgePackagingJob"]:
    """
    List EdgePackagingJob resources, optionally filtered and sorted.

    Parameters:
        creation_time_after: Select jobs where the job was created after specified time.
        creation_time_before: Select jobs where the job was created before specified time.
        last_modified_time_after: Select jobs where the job was updated after specified time.
        last_modified_time_before: Select jobs where the job was updated before specified time.
        name_contains: Filter for jobs containing this name in their packaging job name.
        model_name_contains: Filter for jobs where the model name contains this string.
        status_equals: The job status to filter for.
        sort_by: Use to specify what column to sort by.
        sort_order: What direction to sort by.
        session: Boto3 session.
        region: Region name.

    Returns:
        A ResourceIterator configured for list_edge_packaging_jobs.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Serialize the filter arguments for the list call.
    list_kwargs = serialize(
        {
            "CreationTimeAfter": creation_time_after,
            "CreationTimeBefore": creation_time_before,
            "LastModifiedTimeAfter": last_modified_time_after,
            "LastModifiedTimeBefore": last_modified_time_before,
            "NameContains": name_contains,
            "ModelNameContains": model_name_contains,
            "StatusEquals": status_equals,
            "SortBy": sort_by,
            "SortOrder": sort_order,
        }
    )
    logger.debug(f"Serialized input request: {list_kwargs}")

    return ResourceIterator(
        client=sagemaker,
        list_method="list_edge_packaging_jobs",
        summaries_key="EdgePackagingJobSummaries",
        summary_name="EdgePackagingJobSummary",
        resource_cls=EdgePackagingJob,
        list_method_kwargs=list_kwargs,
    )
[docs]
class Endpoint(Base):
"""
Class representing resource Endpoint
Attributes:
endpoint_name: Name of the endpoint.
endpoint_arn: The Amazon Resource Name (ARN) of the endpoint.
endpoint_status: The status of the endpoint. OutOfService: Endpoint is not available to take incoming requests. Creating: CreateEndpoint is executing. Updating: UpdateEndpoint or UpdateEndpointWeightsAndCapacities is executing. SystemUpdating: Endpoint is undergoing maintenance and cannot be updated or deleted or re-scaled until it has completed. This maintenance operation does not change any customer-specified values such as VPC config, KMS encryption, model, instance type, or instance count. RollingBack: Endpoint fails to scale up or down or change its variant weight and is in the process of rolling back to its previous configuration. Once the rollback completes, endpoint returns to an InService status. This transitional status only applies to an endpoint that has autoscaling enabled and is undergoing variant weight or capacity changes as part of an UpdateEndpointWeightsAndCapacities call or when the UpdateEndpointWeightsAndCapacities operation is called explicitly. InService: Endpoint is available to process incoming requests. Deleting: DeleteEndpoint is executing. Failed: Endpoint could not be created, updated, or re-scaled. Use the FailureReason value returned by DescribeEndpoint for information about the failure. DeleteEndpoint is the only operation that can be performed on a failed endpoint. UpdateRollbackFailed: Both the rolling deployment and auto-rollback failed. Your endpoint is in service with a mix of the old and new endpoint configurations. For information about how to remedy this issue and restore the endpoint's status to InService, see Rolling Deployments.
creation_time: A timestamp that shows when the endpoint was created.
last_modified_time: A timestamp that shows when the endpoint was last modified.
endpoint_config_name: The name of the endpoint configuration associated with this endpoint.
production_variants: An array of ProductionVariantSummary objects, one for each model hosted behind this endpoint.
data_capture_config:
failure_reason: If the status of the endpoint is Failed, the reason why it failed.
last_deployment_config: The most recent deployment configuration for the endpoint.
async_inference_config: Returns the description of an endpoint configuration created using the CreateEndpointConfig API.
pending_deployment_summary: Returns the summary of an in-progress deployment. This field is only returned when the endpoint is creating or updating with a new endpoint configuration.
explainer_config: The configuration parameters for an explainer.
shadow_production_variants: An array of ProductionVariantSummary objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants.
"""
endpoint_name: str
endpoint_arn: Optional[str] = Unassigned()
endpoint_config_name: Optional[str] = Unassigned()
production_variants: Optional[List[ProductionVariantSummary]] = Unassigned()
data_capture_config: Optional[DataCaptureConfigSummary] = Unassigned()
endpoint_status: Optional[str] = Unassigned()
failure_reason: Optional[str] = Unassigned()
creation_time: Optional[datetime.datetime] = Unassigned()
last_modified_time: Optional[datetime.datetime] = Unassigned()
last_deployment_config: Optional[DeploymentConfig] = Unassigned()
async_inference_config: Optional[AsyncInferenceConfig] = Unassigned()
pending_deployment_summary: Optional[PendingDeploymentSummary] = Unassigned()
explainer_config: Optional[ExplainerConfig] = Unassigned()
shadow_production_variants: Optional[List[ProductionVariantSummary]] = Unassigned()
def get_name(self) -> str:
    """
    Return the value of the attribute that holds this endpoint's name.

    Attributes are scanned in the order given by vars(self). The first one
    called "name", or matching a suffix of "endpoint_name" split on
    underscores ("endpoint_name", "name"), supplies the result.

    Returns:
        The matching attribute's value, or None (after logging an error)
        when no such attribute exists.
    """
    name_parts = "endpoint_name".split("_")
    # Every underscore-joined suffix: ["endpoint_name", "name"].
    candidate_names = ["_".join(name_parts[start:]) for start in range(len(name_parts))]

    for attr_name, attr_value in vars(self).items():
        if attr_name == "name" or attr_name in candidate_names:
            return attr_value

    logger.error("Name attribute not found for object endpoint")
    return None
def populate_inputs_decorator(create_func):
    """
    Wrap create_func so its keyword arguments are first passed through
    Base.get_updated_kwargs_with_configured_attributes for "Endpoint".

    Positional arguments are forwarded untouched. The config schema is
    rebuilt on every call.

    Parameters:
        create_func: The function to wrap.

    Returns:
        The wrapping function, carrying create_func's metadata.
    """

    @functools.wraps(create_func)
    def wrapper(*args, **kwargs):
        data_capture_schema = {
            "destination_s3_uri": {"type": "string"},
            "kms_key_id": {"type": "string"},
        }
        async_output_schema = {
            "kms_key_id": {"type": "string"},
            "s3_output_path": {"type": "string"},
            "s3_failure_path": {"type": "string"},
        }
        endpoint_schema = {
            "data_capture_config": data_capture_schema,
            "async_inference_config": {"output_config": async_output_schema},
        }

        resolved_kwargs = Base.get_updated_kwargs_with_configured_attributes(
            endpoint_schema, "Endpoint", **kwargs
        )
        return create_func(*args, **resolved_kwargs)

    return wrapper
[docs]
@classmethod
@populate_inputs_decorator
@Base.add_validate_call
def create(
    cls,
    endpoint_name: str,
    endpoint_config_name: Union[str, object],
    deployment_config: Optional[DeploymentConfig] = Unassigned(),
    tags: Optional[List[Tag]] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Endpoint"]:
    """
    Create an Endpoint and return it as described by the service.

    Parameters:
        endpoint_name: The name of the endpoint. The name must be unique within an
            Amazon Web Services Region in your Amazon Web Services account. The name is
            case-insensitive in CreateEndpoint, but the case is preserved and must be
            matched in InvokeEndpoint.
        endpoint_config_name: The name of an endpoint configuration. For more
            information, see CreateEndpointConfig.
        deployment_config: Forwarded as the DeploymentConfig request field.
        tags: An array of key-value pairs used to categorize Amazon Web Services
            resources, for example by purpose, owner, or environment.
        session: Boto3 session.
        region: Region name.

    Returns:
        The Endpoint resource, obtained by calling get() after creation.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceLimitExceeded: You have exceeded a SageMaker resource limit.
        ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema
        LocalConfigNotFoundError: Raised when a configuration file is not found in local file system
        S3ConfigNotFoundError: Raised when a configuration file is not found in S3
    """
    logger.info("Creating endpoint resource.")
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    request = Base.populate_chained_attributes(
        resource_name="Endpoint",
        operation_input_args={
            "EndpointName": endpoint_name,
            "EndpointConfigName": endpoint_config_name,
            "DeploymentConfig": deployment_config,
            "Tags": tags,
        },
    )
    logger.debug(f"Input request: {request}")

    serialized_request = serialize(request)
    logger.debug(f"Serialized input request: {serialized_request}")

    creation_response = sagemaker.create_endpoint(**serialized_request)
    logger.debug(f"Response: {creation_response}")

    # Describe the new endpoint so the returned object reflects service state.
    return cls.get(endpoint_name=endpoint_name, session=session, region=region)
[docs]
@classmethod
@Base.add_validate_call
def get(
    cls,
    endpoint_name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> Optional["Endpoint"]:
    """
    Look up an existing Endpoint by name.

    Parameters:
        endpoint_name: The name of the endpoint.
        session: Boto3 session.
        region: Region name.

    Returns:
        An Endpoint built from the describe_endpoint response.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    # Serialize the describe request before the client is created.
    request = serialize({"EndpointName": endpoint_name})
    logger.debug(f"Serialized input request: {request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )
    described = sagemaker.describe_endpoint(**request)
    logger.debug(described)

    # Deserialize the response into constructor keyword arguments.
    return cls(**transform(described, "DescribeEndpointOutput"))
[docs]
@Base.add_validate_call
def refresh(self) -> Optional["Endpoint"]:
    """
    Re-describe this Endpoint and update this instance in place.

    The default client from Base.get_sagemaker_client() is used; no
    session or region can be supplied here.

    Returns:
        This same Endpoint instance, after the update.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    request = serialize({"EndpointName": self.endpoint_name})
    logger.debug(f"Serialized input request: {request}")

    described = Base.get_sagemaker_client().describe_endpoint(**request)

    # Passing self makes transform write the deserialized response onto this object.
    transform(described, "DescribeEndpointOutput", self)
    return self
[docs]
@populate_inputs_decorator
@Base.add_validate_call
def update(
    self,
    retain_all_variant_properties: Optional[bool] = Unassigned(),
    exclude_retained_variant_properties: Optional[List[VariantProperty]] = Unassigned(),
    deployment_config: Optional[DeploymentConfig] = Unassigned(),
    retain_deployment_config: Optional[bool] = Unassigned(),
) -> Optional["Endpoint"]:
    """
    Update this Endpoint, then refresh this instance from the service.

    The request always carries this object's current endpoint_name and
    endpoint_config_name alongside the arguments below.

    Parameters:
        retain_all_variant_properties: When updating endpoint resources, enables or
            disables the retention of variant properties, such as the instance count or
            the variant weight. Set to true to retain the endpoint's variant properties;
            set to false to use those specified in a new EndpointConfig. The default is false.
        exclude_retained_variant_properties: When RetainAllVariantProperties is true,
            the list of VariantProperty to override with the values provided by
            EndpointConfig. If not specified, no variant properties are overridden.
        deployment_config: The deployment configuration for an endpoint, which contains
            the desired deployment strategy and rollback configurations.
        retain_deployment_config: Specifies whether to reuse the last deployment
            configuration. The default value is false (the configuration is not reused).

    Returns:
        This same Endpoint instance, refreshed after the update call.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceLimitExceeded: You have exceeded a SageMaker resource limit.
    """
    logger.info("Updating endpoint resource.")
    sagemaker = Base.get_sagemaker_client()

    request = {
        "EndpointName": self.endpoint_name,
        "EndpointConfigName": self.endpoint_config_name,
        "RetainAllVariantProperties": retain_all_variant_properties,
        "ExcludeRetainedVariantProperties": exclude_retained_variant_properties,
        "DeploymentConfig": deployment_config,
        "RetainDeploymentConfig": retain_deployment_config,
    }
    logger.debug(f"Input request: {request}")

    serialized_request = serialize(request)
    logger.debug(f"Serialized input request: {serialized_request}")

    # Send the update, then refresh this object from the service.
    update_response = sagemaker.update_endpoint(**serialized_request)
    logger.debug(f"Response: {update_response}")
    self.refresh()
    return self
[docs]
@Base.add_validate_call
def delete(self) -> None:
    """
    Request deletion of this Endpoint.

    Sends delete_endpoint for this endpoint's name and logs the request.
    The method does not wait for the deletion to finish.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    sagemaker = Base.get_sagemaker_client()

    request = serialize({"EndpointName": self.endpoint_name})
    logger.debug(f"Serialized input request: {request}")

    sagemaker.delete_endpoint(**request)
    logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}")
[docs]
@Base.add_validate_call
def wait_for_status(
    self,
    target_status: Literal[
        "OutOfService",
        "Creating",
        "Updating",
        "SystemUpdating",
        "RollingBack",
        "InService",
        "Deleting",
        "Failed",
        "UpdateRollbackFailed",
    ],
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Block until this Endpoint reports the requested status.

    The endpoint is refreshed on every iteration and its current status is
    shown in a transient Rich panel. The target comparison runs before the
    failure check, so waiting for a status that itself contains "failed"
    returns normally once that status is reached.

    Parameters:
        target_status: The status to wait for.
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
            When None, no timeout is applied.

    Raises:
        TimeoutExceededError: If the target status is not reached before the timeout.
        FailedStatusError: If a non-target status containing "failed" is observed.
    """
    started_at = time.time()

    spinner = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    spinner.add_task(f"Waiting for Endpoint to reach [bold]{target_status} status...")
    status_line = Status("Current status:")
    wait_panel = Panel(
        Group(spinner, status_line),
        title="Wait Log Panel",
        border_style=Style(color=Color.BLUE.value),
    )

    with Live(wait_panel, transient=True):
        while True:
            self.refresh()
            current_status = self.endpoint_status
            status_line.update(f"Current status: [bold]{current_status}")

            if target_status == current_status:
                logger.info(f"Final Resource Status: [bold]{current_status}")
                return

            if "failed" in current_status.lower():
                raise FailedStatusError(
                    resource_type="Endpoint", status=current_status, reason=self.failure_reason
                )

            timed_out = timeout is not None and time.time() - started_at >= timeout
            if timed_out:
                # NOTE(review): the keyword is spelled `resouce_type` here; confirm it
                # matches the TimeoutExceededError signature before renaming.
                raise TimeoutExceededError(resouce_type="Endpoint", status=current_status)
            time.sleep(poll)
[docs]
@Base.add_validate_call
def wait_for_delete(
    self,
    poll: int = 5,
    timeout: Optional[int] = None,
) -> None:
    """
    Block until this Endpoint can no longer be described.

    The endpoint is refreshed on every iteration. A ClientError whose code
    contains "ResourceNotFound" or "ValidationException" is taken to mean
    the endpoint is gone, and the wait ends normally.

    Parameters:
        poll: The number of seconds to wait between each poll.
        timeout: The maximum number of seconds to wait before timing out.
            When None, no timeout is applied.

    Raises:
        botocore.exceptions.ClientError: Re-raised for any service error other
            than the two codes above. The error message and error code can be
            read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        TimeoutExceededError: If the endpoint is still describable after the timeout.
    """
    started_at = time.time()

    spinner = Progress(
        SpinnerColumn("bouncingBar"),
        TextColumn("{task.description}"),
        TimeElapsedColumn(),
    )
    spinner.add_task("Waiting for Endpoint to be deleted...")
    status_line = Status("Current status:")
    wait_panel = Panel(
        Group(spinner, status_line),
        title="Wait Log Panel",
        border_style=Style(color=Color.BLUE.value),
    )

    # Unlike the other waiters, this Live display is not created with transient=True.
    with Live(wait_panel):
        while True:
            try:
                self.refresh()
                current_status = self.endpoint_status
                status_line.update(f"Current status: [bold]{current_status}")

                timed_out = timeout is not None and time.time() - started_at >= timeout
                if timed_out:
                    # NOTE(review): the keyword is spelled `resouce_type` here; confirm it
                    # matches the TimeoutExceededError signature before renaming.
                    raise TimeoutExceededError(resouce_type="Endpoint", status=current_status)
            except botocore.exceptions.ClientError as err:
                error_code = err.response["Error"]["Code"]
                looks_deleted = (
                    "ResourceNotFound" in error_code or "ValidationException" in error_code
                )
                if not looks_deleted:
                    raise err
                logger.info("Resource was not found. It may have been deleted.")
                return
            time.sleep(poll)
[docs]
@classmethod
@Base.add_validate_call
def get_all(
    cls,
    sort_by: Optional[str] = Unassigned(),
    sort_order: Optional[str] = Unassigned(),
    name_contains: Optional[str] = Unassigned(),
    creation_time_before: Optional[datetime.datetime] = Unassigned(),
    creation_time_after: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
    last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
    status_equals: Optional[str] = Unassigned(),
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> ResourceIterator["Endpoint"]:
    """
    List Endpoint resources, optionally filtered and sorted.

    Parameters:
        sort_by: Sorts the list of results. The default is CreationTime.
        sort_order: The sort order for results. The default is Descending.
        name_contains: A string in endpoint names. This filter returns only endpoints
            whose name contains the specified string.
        creation_time_before: A filter that returns only endpoints that were created
            before the specified time (timestamp).
        creation_time_after: A filter that returns only endpoints with a creation time
            greater than or equal to the specified time (timestamp).
        last_modified_time_before: A filter that returns only endpoints that were
            modified before the specified timestamp.
        last_modified_time_after: A filter that returns only endpoints that were
            modified after the specified timestamp.
        status_equals: A filter that returns only endpoints with the specified status.
        session: Boto3 session.
        region: Region name.

    Returns:
        A ResourceIterator configured for list_endpoints.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
    """
    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    # Serialize the filter arguments for the list call.
    list_kwargs = serialize(
        {
            "SortBy": sort_by,
            "SortOrder": sort_order,
            "NameContains": name_contains,
            "CreationTimeBefore": creation_time_before,
            "CreationTimeAfter": creation_time_after,
            "LastModifiedTimeBefore": last_modified_time_before,
            "LastModifiedTimeAfter": last_modified_time_after,
            "StatusEquals": status_equals,
        }
    )
    logger.debug(f"Serialized input request: {list_kwargs}")

    return ResourceIterator(
        client=sagemaker,
        list_method="list_endpoints",
        summaries_key="Endpoints",
        summary_name="EndpointSummary",
        resource_cls=Endpoint,
        list_method_kwargs=list_kwargs,
    )
[docs]
@Base.add_validate_call
def update_weights_and_capacities(
    self,
    desired_weights_and_capacities: List[DesiredWeightAndCapacity],
    session: Optional[Session] = None,
    region: Optional[str] = None,
) -> None:
    """
    Updates variant weight of one or more variants associated with an existing
    endpoint, or capacity of one variant associated with an existing endpoint.

    This instance is not refreshed afterwards.

    Parameters:
        desired_weights_and_capacities: An object that provides new capacity and
            weight values for a variant.
        session: Boto3 session.
        region: Region name.

    Raises:
        botocore.exceptions.ClientError: Raised for AWS service related errors.
            The error message and error code can be read from the exception:
            ```
            try:
                # AWS service call here
            except botocore.exceptions.ClientError as e:
                error_message = e.response['Error']['Message']
                error_code = e.response['Error']['Code']
            ```
        ResourceLimitExceeded: You have exceeded a SageMaker resource limit.
    """
    # Serialize the request before the client is created.
    request = serialize(
        {
            "EndpointName": self.endpoint_name,
            "DesiredWeightsAndCapacities": desired_weights_and_capacities,
        }
    )
    logger.debug(f"Serialized input request: {request}")

    sagemaker = Base.get_sagemaker_client(
        session=session, region_name=region, service_name="sagemaker"
    )

    logger.debug("Calling update_endpoint_weights_and_capacities API")
    api_response = sagemaker.update_endpoint_weights_and_capacities(**request)
    logger.debug(f"Response: {api_response}")
[docs]
@Base.add_validate_call
def invoke(
self,
body: Any,
content_type: Optional[str] = Unassigned(),
accept: Optional[str] = Unassigned(),
custom_attributes: Optional[str] = Unassigned(),
target_model: Optional[str] = Unassigned(),
target_variant: Optional[str] = Unassigned(),
target_container_hostname: Optional[str] = Unassigned(),
inference_id: Optional[str] = Unassigned(),
enable_explanations: Optional[str] = Unassigned(),
inference_component_name: Optional[str] = Unassigned(),
session_id: Optional[str] = Unassigned(),
session: Optional[Session] = None,
region: Optional[str] = None,
) -> Optional[InvokeEndpointOutput]:
"""
After you deploy a model into production using Amazon SageMaker hosting services, your client applications use this API to get inferences from the model hosted at the specified endpoint.
Parameters:
body: Provides input data, in the format specified in the ContentType request header. Amazon SageMaker passes all of the data in the body to the model. For information about the format of the request body, see Common Data Formats-Inference.
content_type: The MIME type of the input data in the request body.
accept: The desired MIME type of the inference response from the model container.
custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK.
target_model: The model to request for inference when invoking a multi-model endpoint.
target_variant: Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights. For information about how to use variant targeting to perform a/b testing, see Test models in production
target_container_hostname: If the endpoint hosts multiple containers and is configured to use direct invocation, this parameter specifies the host name of the container to invoke.
inference_id: If you provide a value, it is added to the captured data when you enable data capture on the endpoint. For information about data capture, see Capture Data.
enable_explanations: An optional JMESPath expression used to override the EnableExplanations parameter of the ClarifyExplainerConfig API. See the EnableExplanations section in the developer guide for more information.
inference_component_name: If the endpoint hosts one or more inference components, this parameter specifies the name of inference component to invoke.
session_id: Creates a stateful session or identifies an existing one. You can do one of the following: Create a stateful session by specifying the value NEW_SESSION. Send your request to an existing stateful session by specifying the ID of that session. With a stateful session, you can send multiple requests to a stateful model. When you create a session with a stateful model, the model must create the session ID and set the expiration time. The model must also provide that information in the response to your request. You can get the ID and timestamp from the NewSessionId response parameter. For any subsequent request where you specify that session ID, SageMaker routes the request to the same instance that supports the session.
session: Boto3 session.
region: Region name.
Returns:
InvokeEndpointOutput
Raises:
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
The error message and error code can be parsed from the exception as follows:
```
try:
# AWS service call here
except botocore.exceptions.ClientError as e:
error_message = e.response['Error']['Message']
error_code = e.response['Error']['Code']
```
InternalDependencyException: Your request caused an exception with an internal dependency. Contact customer support.
InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support.
ModelError: Model (owned by the customer in the container) returned 4xx or 5xx error code.
ModelNotReadyException: Either a serverless endpoint variant's resources are still being provisioned, or a multi-model endpoint is still downloading or loading the target model. Wait and try your request again.
ServiceUnavailable: The service is currently unavailable.
ValidationError: There was an error validating your request.
"""
operation_input_args = {
"EndpointName": self.endpoint_name,
"Body": body,
"ContentType": content_type,
"Accept": accept,
"CustomAttributes": custom_attributes,
"TargetModel": target_model,
"TargetVariant": target_variant,
"TargetContainerHostname": target_container_hostname,
"InferenceId": inference_id,
"EnableExplanations": enable_explanations,
"InferenceComponentName": inference_component_name,
"SessionId": session_id,
}
# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")
client = Base.get_sagemaker_client(
session=session, region_name=region, service_name="sagemaker-runtime"
)
logger.