def read(fname):
    """Return the full contents of a file located next to this script.

    Args:
        fname: File name, joined onto the directory that contains this
            module. ``os.path.join`` leaves an absolute ``fname`` unchanged.

    Returns:
        Whatever ``file.read()`` yields for the file opened in text mode.
    """
    path = os.path.join(os.path.dirname(__file__), fname)
    # A context manager closes the handle deterministically; the previous
    # one-liner left the open file for the garbage collector to clean up.
    with open(path) as f:
        return f.read()
# --- src/sagemaker_containers/__init__.py ---
# Public API of the package. ``__all__`` must hold the exported *names* as
# strings; listing the class objects themselves makes
# ``from sagemaker_containers import *`` fail with a TypeError.
__all__ = ['TrainingEnvironment', 'TrainingEngine', 'App']


# --- src/sagemaker_containers/app.py ---
class App(object):
    """Owns a training engine and runs it.

    A default ``TrainingEngine`` is created on construction; it can be
    swapped for another instance with ``register_engine``.
    """

    def __init__(self):
        self.training_engine = TrainingEngine()

    def register_engine(self, engine):
        """Replace the engine that ``run`` will execute.

        Args:
            engine: A ``TrainingEngine`` instance (subclasses pass the
                ``isinstance`` check as well).

        Raises:
            ValueError: If ``engine`` is not a ``TrainingEngine``.
        """
        if not isinstance(engine, TrainingEngine):
            raise ValueError('Type: %s is not a valid engine type' % type(engine))
        self.training_engine = engine

    def run(self):
        """Delegate to the ``run`` method of the current training engine."""
        self.training_engine.run()
# Root of the container's ML directory tree. It must be absolute: the previous
# value 'opt/ml' was resolved against the current working directory. Each
# ContainerEnvironment reads this module attribute when it is constructed.
base_dir = '/opt/ml'


class ContainerEnvironment(object):
    """Provides access to common aspects of the container environment, including
    important system characteristics, filesystem locations, and configuration settings.
    """

    def __init__(self):
        # Snapshot the module-level setting so later changes to it do not
        # affect instances that already exist.
        self.base_dir = base_dir

    @property
    def model_dir(self):
        """The ``model`` directory under ``base_dir``."""
        return os.path.join(self.base_dir, 'model')

    @property
    def code_dir(self):
        """The directory where user-supplied code will be staged."""
        return os.path.join(self.base_dir, 'code')

    @property
    def available_cpus(self):
        """The number of cpus reported by ``multiprocessing.cpu_count()``."""
        return multiprocessing.cpu_count()

    @property
    def available_gpus(self):
        """The number of gpus available in the current container.

        Not implemented yet: this property currently always returns ``None``.
        """
        pass
# --- src/sagemaker_containers/training_engine.py ---
class TrainingEngine(object):
    """Runs a framework-provided training function against the user's script.

    Args:
        framework_train: Optional callable invoked (through ``UserModule``) as
            ``framework_train(user_module, training_environment)``. It can
            also be registered later with the ``framework_train_fn`` decorator.
    """

    def __init__(self, framework_train=None):
        self.framework_train = framework_train
        self.environment = env.TrainingEnvironment()

    def framework_train_fn(self):
        """Return a decorator that registers the decorated function as the
        framework training function and hands it back unchanged.
        """
        def decorator(train_fn):
            self.framework_train = train_fn
            return train_fn

        return decorator

    def run(self):
        """Import the user module, train, and record the outcome.

        On success ``write_success_file`` is called on the environment. On any
        ``Exception`` ``write_failure_file`` is called with the error message
        and the exception is re-raised to the caller.
        """
        try:
            user_module = UserModule(self.environment.code_dir, self.framework_train)

            user_module.import_()

            user_module.train(self.environment)

            self.environment.write_success_file()
        except Exception as e:
            self.environment.write_failure_file('Uncaught exception during training: %s' % e)
            # Bare ``raise`` keeps the original traceback; ``raise e`` would
            # restart it from this line on Python 2.
            raise


# --- src/sagemaker_containers/training_environment.py ---
class TrainingEnvironment(ContainerEnvironment):
    """Provides access to aspects of the container environment relevant to training jobs.

    Only the directory properties are implemented so far. Every member whose
    body is ``pass`` is a placeholder that currently returns ``None``.
    """

    @property
    def input_dir(self):
        """The base directory for training data and configuration files."""
        return os.path.join(self.base_dir, 'input')

    @property
    def input_config_dir(self):
        """The directory where standard SageMaker configuration files are located."""
        return os.path.join(self.base_dir, 'input/config')

    @property
    def output_dir(self):
        """The directory where training success/failure indications will be written."""
        return os.path.join(self.base_dir, 'output')

    @property
    def resource_config(self):
        """The dict of resource configuration settings. (Placeholder: returns None.)"""
        pass

    @property
    def hyperparameters(self):
        """The dict of hyperparameters that were passed to the CreateTrainingJob API.
        (Placeholder: returns None.)
        """
        pass

    @property
    def current_host(self):
        """The hostname of the current container. (Placeholder: returns None.)"""
        pass

    @property
    def hosts(self):
        """The list of hostnames available to the current training job.
        (Placeholder: returns None.)
        """
        pass

    @property
    def output_data_dir(self):
        """The dir to write non-model training artifacts (e.g. evaluation results) which will be
        retained by SageMaker. (Placeholder: returns None.)
        """
        pass

    @property
    def channels(self):
        """The dict of training input data channel name to directory with the input files for
        that channel. (Placeholder: returns None.)
        """
        pass

    @property
    def channel_dirs(self):
        """Placeholder: returns None.

        NOTE(review): presumably the per-channel input directories, overlapping
        with ``channels`` -- confirm the intended contract before implementing.
        """
        pass

    def load_training_parameters(self, fn):
        """Placeholder: does nothing with ``fn`` and returns None."""
        pass

    def write_success_file(self):
        """Placeholder: currently writes nothing and returns None."""
        pass

    def write_failure_file(self, message):
        """Placeholder: currently ignores ``message``, writes nothing and returns None."""
        pass
class UserModule(object):
    """Imports the user-supplied script from ``code_dir`` and trains with it.

    Args:
        code_dir: Directory that is put on ``sys.path`` so the module named
            by ``self.user_script`` (``'user_script'``) can be imported.
        framework_train_fn: Callable invoked as
            ``framework_train_fn(user_module, training_environment)``.
    """

    def __init__(self, code_dir, framework_train_fn):
        self.code_dir = code_dir
        self.framework_train_fn = framework_train_fn
        self.user_script = 'user_script'
        # Filled in by import_(); defined up front so the attribute always
        # exists, even when train() is reached before import_().
        self.user_module = None

    def import_(self):
        """Import the user script and store the module on ``self.user_module``."""
        # code_dir must be first on sys.path so the user's script wins the
        # lookup; skip the insert when it is already there so repeated calls
        # do not pile up duplicate entries.
        if sys.path[:1] != [self.code_dir]:
            sys.path.insert(0, self.code_dir)
        self.user_module = importlib.import_module(self.user_script)

    def train(self, training_environment):
        """Call the framework training function with the user module.

        The user module is imported on demand if ``import_`` has not been
        called yet.

        Args:
            training_environment: Passed through unchanged as the second
                argument of ``framework_train_fn``.
        """
        if self.user_module is None:
            self.import_()
        self.framework_train_fn(self.user_module, training_environment)
@pytest.fixture(name='model_file')
def fixture_model_file(tmpdir):
    """Path (as ``str``) of ``saved_model.json`` inside a fresh ``output`` dir under ``tmpdir``."""
    return str(tmpdir.mkdir('output').join('saved_model.json'))


@pytest.fixture(name='create_user_module')
def fixture_create_user_module(tmpdir, monkeypatch):
    """Point the container base dir at ``tmpdir`` and stage a minimal user script there."""
    # monkeypatch restores the module attribute at teardown; assigning it
    # directly would leak this tmpdir into every test that runs afterwards.
    monkeypatch.setattr(environment, 'base_dir', str(tmpdir))
    user_script_file = tmpdir.mkdir('code').join('user_script.py')

    script = "def train(chanel_dirs, hps): return {'trained': True, 'saved': False}"

    user_script_file.write(script)


def test_register_training_with_decorator(create_user_module, model_file):
    """The train function is registered via the engine decorator, then the engine is registered on the app."""
    engine = TrainingEngine()

    @engine.framework_train_fn()
    def framework_train(user_module, training_environment):
        model = user_module.train(training_environment.channel_dirs, training_environment.hyperparameters)
        save_model(model, model_file)

    app = App()
    app.register_engine(engine)

    app.run()

    assert load_model(model_file) == {'trained': True, 'saved': True}


def test_register_training_with_fn(create_user_module, model_file):
    """The train function is passed straight to the ``TrainingEngine`` constructor."""
    def framework_train(user_module, training_environment):
        model = user_module.train(training_environment.channel_dirs, training_environment.hyperparameters)
        save_model(model, model_file)

    engine = TrainingEngine(framework_train)

    app = App()
    app.register_engine(engine)

    app.run()

    assert load_model(model_file) == {'trained': True, 'saved': True}


def test_app_run_with_decorator(create_user_module, model_file):
    """The train function is registered on the app's default engine via the decorator."""
    app = App()

    @app.training_engine.framework_train_fn()
    def framework_train(user_module, training_environment):
        model = user_module.train(training_environment.channel_dirs, training_environment.hyperparameters)
        save_model(model, model_file)

    app.run()

    assert load_model(model_file) == {'trained': True, 'saved': True}


def save_model(model, model_file):
    """Mark ``model`` as saved (mutating the dict in place) and dump it as JSON to ``model_file``."""
    model['saved'] = True
    with open(model_file, 'w') as f:
        json.dump(model, f)


def load_model(model_file):
    """Return the object decoded from the JSON file at ``model_file``."""
    with open(model_file, 'r') as f:
        return json.load(f)