这是indexloc提供的服务,不要输入任何密码
Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def read(file_name):
install_requires=['boto3', 'six', 'pip'],

extras_require={
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'mock', 'sagemaker', 'keras', 'tensorflow', 'numpy']
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'mock', 'sagemaker', 'numpy']
}
)
22 changes: 16 additions & 6 deletions src/sagemaker_containers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,29 @@
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import collections

SplitResultSpec = collections.namedtuple('SplitResultSpec', 'included excluded')


def split_by_criteria(dictionary, keys):  # type: (dict, set or list or tuple) -> SplitResultSpec
    """Split a dictionary in two by the provided keys.

    Args:
        dictionary (dict[str, object]): A Python dictionary.
        keys (iterable[str]): The keys used as the split criteria (a set, list, tuple, or any other
            iterable of hashable keys). Keys that are absent from ``dictionary`` are ignored.

    Returns:
        `SplitResultSpec`: A collections.namedtuple with the following attributes (it also unpacks as
        ``included, excluded``):

            * included (dict[str, object]): the items of ``dictionary`` whose key is in ``keys``.
            * excluded (dict[str, object]): the items of ``dictionary`` whose key is not in ``keys``.
    """
    # Materialize the criteria once: O(1) membership tests, and a one-shot iterator is consumed only once.
    criteria = set(keys)

    included_items = {}
    excluded_items = {}
    # Single pass over the items; insertion order of ``dictionary`` is preserved in both results.
    for key, value in dictionary.items():
        target = included_items if key in criteria else excluded_items
        target[key] = value

    return SplitResultSpec(included=included_items, excluded=excluded_items)
41 changes: 35 additions & 6 deletions src/sagemaker_containers/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
from __future__ import absolute_import

import collections
import contextlib
import json
import logging
import multiprocessing
import os
import shlex
import shutil
import subprocess
import sys
import tempfile

import boto3
import six
Expand All @@ -43,14 +46,18 @@
# Names of the JSON configuration files found in the input config directory.
HYPERPARAMETERS_FILE = 'hyperparameters.json'  # type: str
RESOURCE_CONFIG_FILE = 'resourceconfig.json'  # type: str
INPUT_DATA_CONFIG_FILE = 'inputdataconfig.json'  # type: str

# Top-level directories under BASE_PATH.
MODEL_PATH = os.path.join(BASE_PATH, 'model')  # type: str
INPUT_PATH = os.path.join(BASE_PATH, 'input')  # type: str
OUTPUT_PATH = os.path.join(BASE_PATH, 'output')  # type: str

INPUT_DATA_PATH = os.path.join(INPUT_PATH, 'data')  # type: str
OUTPUT_DATA_PATH = os.path.join(OUTPUT_PATH, 'data')  # type: str

# Two names for the same directory (<INPUT_PATH>/config): it holds all three config files above,
# not only the input data config, despite what INPUT_DATA_CONFIG_PATH suggests.
INPUT_CONFIG_PATH = os.path.join(INPUT_PATH, 'config')  # type: str
INPUT_DATA_CONFIG_PATH = INPUT_CONFIG_PATH  # type: str

# Full paths of the config files.
HYPERPARAMETERS_PATH = os.path.join(INPUT_DATA_CONFIG_PATH, HYPERPARAMETERS_FILE)  # type: str
INPUT_DATA_CONFIG_FILE_PATH = os.path.join(INPUT_DATA_CONFIG_PATH, INPUT_DATA_CONFIG_FILE)  # type: str
RESOURCE_CONFIG_PATH = os.path.join(INPUT_DATA_CONFIG_PATH, RESOURCE_CONFIG_FILE)  # type: str

# Hyperparameter names reserved for the container itself.
PROGRAM_PARAM = 'sagemaker_program'  # type: str
SUBMIT_DIR_PARAM = 'sagemaker_submit_directory'  # type: str
ENABLE_METRICS_PARAM = 'sagemaker_enable_cloudwatch_metrics'  # type: str
Expand Down Expand Up @@ -85,7 +92,7 @@ def read_hyperparameters(): # type: () -> dict
Returns:
(dict[string, object]): a dictionary containing the hyperparameters.
"""
hyperparameters = read_json(os.path.join(INPUT_CONFIG_PATH, HYPERPARAMETERS_FILE))
hyperparameters = read_json(HYPERPARAMETERS_PATH)

try:
return {k: json.loads(v) for k, v in hyperparameters.items()}
Expand Down Expand Up @@ -115,7 +122,7 @@ def read_resource_config(): # type: () -> dict
sorted lexicographically. For example, `['algo-1', 'algo-2', 'algo-3']`
for a three-node cluster.
"""
return read_json(os.path.join(INPUT_CONFIG_PATH, RESOURCE_CONFIG_FILE))
return read_json(RESOURCE_CONFIG_PATH)


def read_input_data_config(): # type: () -> dict
Expand Down Expand Up @@ -148,7 +155,7 @@ def read_input_data_config(): # type: () -> dict
Returns:
input_data_config (dict[string, object]): contents from /opt/ml/input/config/inputdataconfig.json.
"""
return read_json(os.path.join(INPUT_CONFIG_PATH, INPUT_DATA_CONFIG_FILE))
return read_json(INPUT_DATA_CONFIG_FILE_PATH)


def channel_path(channel): # type: (str) -> str
Expand Down Expand Up @@ -587,6 +594,7 @@ def create(cls, session=None): # type: (boto3.Session) -> Environment
input_data_config = read_input_data_config()

hyperparameters = read_hyperparameters()

sagemaker_hyperparameters, hyperparameters = smc.collections.split_by_criteria(hyperparameters,
SAGEMAKER_HYPERPARAMETERS)

Expand All @@ -597,7 +605,7 @@ def create(cls, session=None): # type: (boto3.Session) -> Environment
os.environ[REGION_PARAM_NAME.upper()] = sagemaker_region

return cls(input_dir=INPUT_PATH,
input_config_dir=INPUT_CONFIG_PATH,
input_config_dir=INPUT_DATA_CONFIG_PATH,
model_dir=MODEL_PATH,
output_dir=OUTPUT_PATH,
output_data_dir=OUTPUT_DATA_PATH,
Expand All @@ -616,7 +624,7 @@ def create(cls, session=None): # type: (boto3.Session) -> Environment
)

@staticmethod
def _parse_module_name(program_param):
def _parse_module_name(program_param): # type: (str) -> str
"""Given a module name or a script name, Returns the module name.

This function is used for backwards compatibility.
Expand All @@ -630,3 +638,24 @@ def _parse_module_name(program_param):
if program_param.endswith('.py'):
return program_param[:-3]
return program_param


@contextlib.contextmanager
def temporary_directory(suffix='', prefix='tmp', dir=None):  # type: (str, str, str) -> Iterator[str]
    """Create a temporary directory as a context manager; it is removed when the context exits.

    The directory and everything inside it is deleted on exit, including when the body of the
    ``with`` statement raises an exception. The suffix, prefix, and dir arguments are passed
    unchanged to ``tempfile.mkdtemp()``.

    Args:
        suffix (str): If suffix is specified, the directory name will end with that suffix; otherwise
            there will be no suffix.
        prefix (str): The directory name will begin with this prefix (defaults to 'tmp').
        dir (str): If dir is specified, the directory will be created inside it; otherwise, a default
            directory is used.

    Yields:
        (str) path to the directory
    """
    tmpdir = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
    try:
        yield tmpdir
    finally:
        # Runs on normal exit AND when the with-body raises; without this the directory would leak.
        shutil.rmtree(tmpdir)
40 changes: 15 additions & 25 deletions src/sagemaker_containers/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,31 @@ def train(channel_dirs, model_dir): pass
Returns:
(dict) a dictionary with only matching arguments.
"""
args, _, kwargs = signature(fn)
arg_spec = getargspec(fn)

if kwargs:
if arg_spec.keywords:
return dictionary

return smc.collections.split_by_criteria(dictionary, set(args))[0]
return smc.collections.split_by_criteria(dictionary, arg_spec.args).included


# ``inspect.ArgSpec`` (like ``inspect.getargspec``) was removed in Python 3.11; fall back to an
# equivalent namedtuple with the same field names so callers can keep using ``.keywords`` etc.
try:
    ArgSpec = inspect.ArgSpec
except AttributeError:
    import collections

    ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')


def getargspec(fn):  # type: (function) -> ArgSpec
    """Get the names and default values of a function's arguments.

    Works on Python 2 and on every Python 3 version, including 3.11+ where ``inspect.getargspec``
    and ``inspect.ArgSpec`` no longer exist.

    Args:
        fn (function): a function

    Returns:
        `ArgSpec`: A collections.namedtuple with the following attributes:

            * args (list): a list of the argument names (it may contain nested lists on Python 2).
            * varargs (str): name of the * argument or None.
            * keywords (str): name of the ** argument or None.
            * defaults (tuple): an n-tuple of the default values of the last n arguments, or None.
    """
    # Feature detection instead of version flags: getfullargspec exists on every Python 3.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is None:
        # Python 2: getargspec already returns (args, varargs, keywords, defaults).
        return inspect.getargspec(fn)

    full_arg_spec = getfullargspec(fn)
    # Python 3 renamed ``keywords`` to ``varkw``; map it back to the legacy field name.
    return ArgSpec(full_arg_spec.args, full_arg_spec.varargs, full_arg_spec.varkw, full_arg_spec.defaults)


def signature(fn):  # type: (function) -> ([], [], [])
    """Given a function fn, returns the function args, vargs and kwargs.

    Kept for backward compatibility; new code should use `getargspec`.

    Args:
        fn (function): a function

    Returns:
        ([], str, str): a tuple containing the argument names, the name of the * argument (or None)
            and the name of the ** argument (or None).
    """
    arg_spec = getargspec(fn)
    return arg_spec.args, arg_spec.varargs, arg_spec.keywords
28 changes: 14 additions & 14 deletions src/sagemaker_containers/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
import logging
import os
import shlex
import shutil
import subprocess
import sys
import tarfile
import tempfile
import traceback

import boto3
from six.moves.urllib.parse import urlparse

import sagemaker_containers as smc

logger = logging.getLogger(__name__)

DEFAULT_MODULE_NAME = 'default_user_module_name'
Expand Down Expand Up @@ -97,19 +97,19 @@ def download_and_import(url, name=DEFAULT_MODULE_NAME): # type: (str, str) -> m
Returns:
(module): the imported module
"""
with tempfile.NamedTemporaryFile() as tmp:
s3_download(url, tmp.name)
with smc.environment.temporary_directory() as tmpdir:
dst = os.path.join(tmpdir, 'tar_file')
s3_download(url, dst)

module_path = os.path.join(tmpdir, 'module_dir')

os.makedirs(module_path)

with open(tmp.name, 'rb') as f:
with tarfile.open(mode='r:gz', fileobj=f) as t:
tmpdir = tempfile.mkdtemp()
try:
t.extractall(path=tmpdir)
with tarfile.open(name=dst, mode='r:gz') as t:
t.extractall(path=module_path)

prepare(tmpdir, name)
prepare(module_path, name)

install(tmpdir)
install(module_path)

return importlib.import_module(name)
finally:
shutil.rmtree(tmpdir)
return importlib.import_module(name)
Loading