这是indexloc提供的服务,不要输入任何密码
Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
from glob import glob
from os.path import basename
from os.path import splitext

from setuptools import setup, find_packages


def read(fname):
    """Return the contents of ``fname``, resolved next to this script.

    Args:
        fname: File name (or relative path) inside the directory that
            contains this file.

    Returns:
        The complete contents of the file, as returned by ``file.read()``.
    """
    path = os.path.join(os.path.dirname(__file__), fname)
    # A context manager closes the handle deterministically instead of
    # leaving it open until the garbage collector gets to it.
    with open(path) as f:
        return f.read()


# Trove classifiers published with the distribution metadata.
_CLASSIFIERS = [
    'Development Status :: 5 - Production/Stable',
    'Intended Audience :: Developers',
    'Natural Language :: English',
    'License :: OSI Approved :: Apache Software License',
    'Programming Language :: Python',
    'Programming Language :: Python :: 2.7',
    'Programming Language :: Python :: 3.5',
]

# Names of the modules that live directly under src/ (file name minus extension).
_TOP_LEVEL_MODULES = [
    splitext(basename(module_path))[0] for module_path in glob('src/*.py')
]

setup(
    name='sagemaker-containers',
    version='1.0',
    description='Open source library for creating containers to run on Amazon SageMaker.',
    long_description=read('README.md'),
    author='Amazon Web Services',
    url='https://github.com/aws/sagemaker-container-support/',
    license='Apache License 2.0',

    # Packages are discovered under src/, skipping the 'test' package.
    packages=find_packages(where='src', exclude=('test',)),
    package_dir={'': 'src'},
    py_modules=_TOP_LEVEL_MODULES,

    classifiers=_CLASSIFIERS,
)
19 changes: 19 additions & 0 deletions src/sagemaker_containers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker_containers.app import App
from sagemaker_containers.training_engine import TrainingEngine
from sagemaker_containers.training_environment import TrainingEnvironment

# ``__all__`` must list the exported *names* as strings. Listing the objects
# themselves makes ``from sagemaker_containers import *`` fail with TypeError.
__all__ = ['TrainingEnvironment', 'TrainingEngine', 'App']
29 changes: 29 additions & 0 deletions src/sagemaker_containers/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker_containers.training_engine import TrainingEngine


class App(object):
    """Owns a training engine and runs it.

    A default ``TrainingEngine`` is created on construction and can be
    replaced through ``register_engine``.
    """

    def __init__(self):
        # Start with a default engine so ``run`` works without registration.
        self.training_engine = TrainingEngine()

    def register_engine(self, engine):
        """Use ``engine`` as this app's training engine.

        Args:
            engine: The engine to register; must be a ``TrainingEngine``.

        Raises:
            ValueError: If ``engine`` is not a ``TrainingEngine`` instance.
        """
        if not isinstance(engine, TrainingEngine):
            raise ValueError('Type: %s is not a valid engine type' % (type(engine),))

        self.training_engine = engine

    def run(self):
        """Run the currently registered training engine."""
        engine = self.training_engine
        engine.run()
48 changes: 48 additions & 0 deletions src/sagemaker_containers/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import multiprocessing
import os

# Default root under which the container's model/code/input/output directories
# are resolved. It has to be absolute: a relative 'opt/ml' would be resolved
# against whatever the process's current working directory happens to be.
base_dir = '/opt/ml'


class ContainerEnvironment(object):
    """Read-only view of the container: filesystem locations and system facts.

    Every directory is derived from ``self.base_dir``, which is captured from
    the module-level ``base_dir`` when the instance is created.
    """

    def __init__(self):
        self.base_dir = base_dir

    @property
    def model_dir(self):
        """The ``model`` directory under the base directory."""
        root = self.base_dir
        return os.path.join(root, 'model')

    @property
    def code_dir(self):
        """The ``code`` directory under the base directory (user-supplied code is staged here)."""
        root = self.base_dir
        return os.path.join(root, 'code')

    @property
    def available_cpus(self):
        """CPU count as reported by ``multiprocessing.cpu_count()``."""
        cpu_total = multiprocessing.cpu_count()
        return cpu_total

    @property
    def available_gpus(self):
        """Intended to be the number of GPUs in the container.

        Not implemented yet: currently always evaluates to ``None``.
        """
        return None



43 changes: 43 additions & 0 deletions src/sagemaker_containers/training_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sagemaker_containers.training_environment as env
from sagemaker_containers.user_module import UserModule


class TrainingEngine(object):
    """Runs a framework-provided training function against the user's module.

    The framework function is called as ``fn(user_module, training_environment)``.
    It can be passed to the constructor or registered later with the decorator
    returned by ``framework_train_fn``.
    """

    def __init__(self, framework_train=None):
        """Create an engine.

        Args:
            framework_train: Optional callable taking
                ``(user_module, training_environment)``.
        """
        self.framework_train = framework_train
        self.environment = env.TrainingEnvironment()

    def framework_train_fn(self):
        """Return a decorator that registers a function as the framework train function.

        The decorated function is stored on the engine and returned unchanged,
        so it stays usable under its own name.
        """
        def decorator(train_fn):
            self.framework_train = train_fn
            return train_fn

        return decorator

    def run(self):
        """Import the user module, train, and record the outcome.

        On success the environment's success file is written. On any
        ``Exception`` the failure file is written with the error message and
        the exception is re-raised to the caller.
        """
        try:
            user_module = UserModule(self.environment.code_dir, self.framework_train)

            user_module.import_()

            user_module.train(self.environment)

            self.environment.write_success_file()
        except Exception as e:
            self.environment.write_failure_file('Uncaught exception during training: %s' % e)
            # Bare ``raise`` re-raises with the original traceback intact;
            # ``raise e`` would replace it with this line's on Python 2.
            raise
82 changes: 82 additions & 0 deletions src/sagemaker_containers/training_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

from sagemaker_containers.environment import ContainerEnvironment


class TrainingEnvironment(ContainerEnvironment):
    """Container environment specialised for training jobs.

    Adds training-related locations on top of ``ContainerEnvironment``. Several
    members are placeholders that are not implemented yet; those evaluate to
    ``None``.
    """

    @property
    def input_dir(self):
        """Base directory for training data and configuration files."""
        root = self.base_dir
        return os.path.join(root, 'input')

    @property
    def input_config_dir(self):
        """Directory where the standard SageMaker configuration files live."""
        root = self.base_dir
        return os.path.join(root, 'input/config')

    @property
    def output_dir(self):
        """Directory where training success/failure indications are written."""
        root = self.base_dir
        return os.path.join(root, 'output')

    @property
    def resource_config(self):
        """Dict of resource configuration settings (placeholder: currently ``None``)."""
        return None

    @property
    def hyperparameters(self):
        """Dict of hyperparameters passed to the CreateTrainingJob API (placeholder: currently ``None``)."""
        return None

    @property
    def current_host(self):
        """Hostname of the current container (placeholder: currently ``None``)."""
        return None

    @property
    def hosts(self):
        """List of hostnames available to the training job (placeholder: currently ``None``)."""
        return None

    @property
    def output_data_dir(self):
        """Directory for non-model training artifacts (e.g. evaluation results) retained by SageMaker.

        Placeholder: currently ``None``.
        """
        return None

    @property
    def channels(self):
        """Dict mapping each input data channel name to its directory (placeholder: currently ``None``)."""
        return None

    @property
    def channel_dirs(self):
        """Placeholder without a documented contract yet; currently ``None``."""
        return None

    def load_training_parameters(self, fn):
        """Placeholder: accepts ``fn``, does nothing and returns ``None``."""
        return None

    def write_success_file(self):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def write_failure_file(self, message):
        """Placeholder: accepts ``message``, does nothing and returns ``None``."""
        return None
29 changes: 29 additions & 0 deletions src/sagemaker_containers/user_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import importlib

import sys


class UserModule(object):
    """Wraps the user's training script so a framework function can run it.

    The script is imported by module name (``self.user_script``) after
    ``code_dir`` has been put at the front of ``sys.path``.
    """

    def __init__(self, code_dir, framework_train_fn):
        """Remember where the user code lives and which function trains it.

        Args:
            code_dir: Directory added to ``sys.path`` when importing.
            framework_train_fn: Callable invoked as
                ``fn(user_module, training_environment)``.
        """
        self.user_script = 'user_script'
        self.framework_train_fn = framework_train_fn
        self.code_dir = code_dir

    def import_(self):
        """Import the user script and keep the module on ``self.user_module``."""
        search_path = sys.path
        # Front of the path so the user's script wins over same-named modules.
        search_path.insert(0, self.code_dir)
        module_name = self.user_script
        self.user_module = importlib.import_module(module_name)

    def train(self, training_environment):
        """Call the framework train function with the imported module.

        ``import_`` must have been called first so ``self.user_module`` exists.
        The framework function's return value is discarded.
        """
        train_fn = self.framework_train_fn
        train_fn(self.user_module, training_environment)
87 changes: 87 additions & 0 deletions test/functional/test_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import json

import pytest
import sagemaker_containers.environment as environment
from sagemaker_containers import TrainingEngine, App


@pytest.fixture(name='model_file')
def fixture_model_file(tmpdir):
    """Path (as ``str``) of ``saved_model.json`` in a new ``output`` dir under ``tmpdir``.

    Only the directory is created; the file itself is not.
    """
    output_dir = tmpdir.mkdir('output')
    model_path = output_dir.join('saved_model.json')
    return str(model_path)


@pytest.fixture(name='create_user_module')
def fixture_create_user_module(tmpdir, monkeypatch):
    """Point the container base dir at ``tmpdir`` and stage a minimal user script.

    Writes ``code/user_script.py`` under ``tmpdir`` with a ``train`` function
    that returns a fixed dict. Returns nothing; tests use it for its side
    effects only.
    """
    # monkeypatch restores the original ``base_dir`` after each test, so the
    # override cannot leak into other tests that import the environment module.
    monkeypatch.setattr(environment, 'base_dir', str(tmpdir))
    user_script_file = tmpdir.mkdir('code').join('user_script.py')

    script = "def train(chanel_dirs, hps): return {'trained': True, 'saved': False}"

    user_script_file.write(script)


def test_register_training_with_decorator(create_user_module, model_file):
    """A train function registered via the engine's decorator is run by the app."""
    training_engine = TrainingEngine()

    @training_engine.framework_train_fn()
    def train_and_save(user_module, training_environment):
        channel_dirs = training_environment.channel_dirs
        hyperparameters = training_environment.hyperparameters
        trained_model = user_module.train(channel_dirs, hyperparameters)
        save_model(trained_model, model_file)

    application = App()
    application.register_engine(training_engine)
    application.run()

    expected = {'trained': True, 'saved': True}
    assert load_model(model_file) == expected


def test_register_training_with_fn(create_user_module, model_file):
    """A train function passed to the engine's constructor is run by the app."""
    def train_and_save(user_module, training_environment):
        channel_dirs = training_environment.channel_dirs
        hyperparameters = training_environment.hyperparameters
        trained_model = user_module.train(channel_dirs, hyperparameters)
        save_model(trained_model, model_file)

    training_engine = TrainingEngine(train_and_save)

    application = App()
    application.register_engine(training_engine)
    application.run()

    expected = {'trained': True, 'saved': True}
    assert load_model(model_file) == expected


def test_app_run_with_decorator(create_user_module, model_file):
    """A train function registered on the app's default engine is run by the app."""
    application = App()

    @application.training_engine.framework_train_fn()
    def train_and_save(user_module, training_environment):
        channel_dirs = training_environment.channel_dirs
        hyperparameters = training_environment.hyperparameters
        trained_model = user_module.train(channel_dirs, hyperparameters)
        save_model(trained_model, model_file)

    application.run()

    expected = {'trained': True, 'saved': True}
    assert load_model(model_file) == expected


def save_model(model, model_file):
    """Mark ``model`` as saved (in place) and write it to ``model_file`` as JSON.

    Args:
        model: Dict to persist; its ``'saved'`` key is set to ``True``.
        model_file: Path of the file to (over)write.
    """
    model.update(saved=True)
    serialized = json.dumps(model)
    with open(model_file, 'w') as out:
        out.write(serialized)


def load_model(model_file):
    """Read ``model_file`` and return the object decoded from its JSON content."""
    with open(model_file, 'r') as src:
        raw = src.read()
    return json.loads(raw)