feat: Add sample for Vertex distributed training #4163
Summary of Changes
Hello @erwinh85, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new, end-to-end sample for performing large-scale distributed training of a Llama 3.1 8B model on Vertex AI. It provides all the necessary components, from custom container creation and configuration to job submission and execution, leveraging NVIDIA's NeMo framework and A3 Mega VMs. The goal is to offer a clear, reproducible guide for users looking to run high-performance LLM pre-training on Google Cloud's AI infrastructure.
Highlights
- New Distributed Training Sample: I've added a comprehensive new sample demonstrating distributed training on Google Cloud Vertex AI. This sample focuses on pre-training a Llama 3.1 8B model using NVIDIA's NeMo Framework on A3 Mega VMs (H100 GPUs).
- Custom Container Build Process: The PR includes a `Dockerfile` and `cloudbuild.yml` to build a custom NeMo-based container image. This image is tailored for Vertex AI, incorporating necessary dependencies and patches to the NeMo framework for optimal performance and compatibility, including handling GCS data access and specific distributed training configurations.
- Automated Job Launching: A Python script (`launch.py`) is provided to simplify the submission of the custom training job to Vertex AI. It reads job parameters from a JSON configuration file, allowing users to easily define machine types, GPU counts, and scheduling strategies (spot, DWS, or reservations).
- Robust Distributed Execution: The sample's entrypoint script (`run.py`) handles environment setup, data download from GCS, and orchestrates the multi-node `torchrun` command. It includes a retry mechanism with `torch.distributed.barrier()` to enhance reliability for distributed jobs, especially when using preemptible instances.
- NeMo Configuration and Patches: A detailed NeMo configuration (`llama3_1_8b_pretrain_a3mega.yaml`) is included, defining the Llama 3.1 8B model architecture and training hyperparameters. Several patches are applied to the NeMo framework within the Docker build to address specific integration needs with Vertex AI and improve logging/compatibility.
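The retry-and-barrier pattern mentioned above can be sketched without any GPU setup. The helper below is purely illustrative: the sample itself uses the `retrying` package plus `torch.distributed.barrier()`, and the names, attempt counts, and delays here are assumptions, not the PR's actual code.

```python
import time


def retry(fn, attempts=3, delay=0.01):
    """Illustrative retry helper (the real sample uses the `retrying`
    package; this signature is an assumption, not the PR's API)."""
    for i in range(attempts):
        try:
            return fn()
        except RuntimeError:
            if i == attempts - 1:
                raise  # out of attempts: surface the failure
            time.sleep(delay)


calls = {"n": 0}


def flaky_step():
    # Fails twice (e.g. a transient error from a preempted node), then succeeds.
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


result = retry(flaky_step)
print(result, calls["n"])  # ok 3
```

In the actual sample, a successful retry is followed by `torch.distributed.barrier()` so that recovered workers re-synchronize before continuing, which is what makes the mechanism useful on preemptible instances.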
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Code Review
This pull request adds a new sample for distributed training of a Llama 3.1 8B model on Vertex AI. I've found several critical issues, primarily related to incorrect or non-existent package dependencies that will prevent the code from running, as well as a potential bug in the distributed training logic within a patch file. I've also provided suggestions to improve code robustness and adhere to modern practices.
```
opencv-python-headless==4.11.0.86
docutils==0.16
urllib3==2.0.7
google-cloud-storage==3.0.0
```
```
absl-py==2.2.2
annotated-types==0.7.0
anyio==4.9.0
black==25.1.0
cachetools==5.5.2
certifi==2025.4.26
charset-normalizer==3.4.2
click==8.1.8
docstring_parser==0.16
google-api-core==2.24.2
google-auth==2.40.1
google-cloud-aiplatform==1.92.0
google-cloud-bigquery==3.31.0
google-cloud-core==2.4.3
google-cloud-resource-manager==1.14.2
google-cloud-storage==2.19.0
google-crc32c==1.7.1
google-genai==1.14.0
google-resumable-media==2.7.2
googleapis-common-protos==1.70.0
grpc-google-iam-v1==0.14.2
grpcio==1.71.0
grpcio-status==1.71.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
idna==3.10
mypy_extensions==1.1.0
numpy==2.2.5
packaging==25.0
pathspec==0.12.1
platformdirs==4.3.8
proto-plus==1.26.1
protobuf==5.29.4
pyasn1==0.6.1
pyasn1_modules==0.4.2
pydantic==2.11.4
pydantic_core==2.33.2
python-dateutil==2.9.0.post0
pytz==2025.2
requests==2.32.3
rsa==4.9.1
shapely==2.1.0
six==1.17.0
sniffio==1.3.1
typing-inspection==0.4.0
typing_extensions==4.13.2
urllib3==2.4.0
websockets==15.0.1
```
```python
from absl import app, flags, logging
from util import cluster_spec

from retrying import retry
```
```python
reservation_name = config["reservation_name"]
reservation_full_name = (
    f"projects/{project_id}/zones/{zone}/reservations/{reservation_name}"
    if "reservation_name" in config
    else ""
)
```
Accessing `config["reservation_name"]` directly will raise a `KeyError` if it's not present. Use `config.get("reservation_name")` to safely access the value. The same issue exists for `config["zone"]` on line 101.
Suggested change:

```python
reservation_name = config.get("reservation_name")
reservation_full_name = (
    f"projects/{project_id}/zones/{zone}/reservations/{reservation_name}"
    if reservation_name and zone
    else ""
)
```
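A quick illustration of the difference. The dict below is a stand-in, not the sample's real config; only `reservation_name` being optional is taken from the review comment above.

```python
# Hypothetical minimal config that is missing the optional reservation fields.
config = {"project_id": "my-project", "bucket": "my-bucket"}

# Direct indexing raises KeyError for an absent optional key:
missing = False
try:
    config["reservation_name"]
except KeyError:
    missing = True

# .get() returns None (or a supplied default) instead of raising:
reservation_name = config.get("reservation_name")

print(missing, reservation_name)  # True None
```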
```python
class ClusterSpecTest(googletest.TestCase):

    def setUp(self):
        super().setUp()
        self.curr_env_var = os.environ.copy()

    def tearDown(self):
        super().tearDown()
        os.environ = self.curr_env_var

    def test_get_cluster_spec_from_env_vars(self):
        os.environ["CLUSTER_SPEC"] = ""
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "8080"
        os.environ["RANK"] = "0"
        os.environ["NNODES"] = "2"
        cluster_info = cluster_spec.get_cluster_spec()
        self.assertEqual(cluster_info.primary_node_addr, "127.0.0.1")
        self.assertEqual(cluster_info.primary_node_port, "8080")
        self.assertEqual(cluster_info.node_rank, 0)
        self.assertEqual(cluster_info.num_nodes, 2)

    def test_get_cluster_spec_from_cluster_spec(self):
        os.environ["CLUSTER_SPEC"] = """
        {
          "cluster": {
            "workerpool0": ["127.0.0.1:8080"],
            "workerpool1": ["127.0.0.2:8080", "127.0.0.3:8080"]
          },
          "task": {"type": "workerpool1", "index": 0}
        }
        """
        cluster_info = cluster_spec.get_cluster_spec()
        self.assertEqual(cluster_info.primary_node_addr, "127.0.0.1")
        self.assertEqual(cluster_info.primary_node_port, "8080")
        self.assertEqual(cluster_info.node_rank, 1)
        self.assertEqual(cluster_info.num_nodes, 3)


if __name__ == "__main__":
    googletest.main()
```
This test file is set up for `googletest`, which is a C++ testing framework. Use the standard library `unittest` instead.
```python
import os
import unittest

from . import cluster_spec


# TODO(styer): Use pytest instead
class ClusterSpecTest(unittest.TestCase):

    def setUp(self):
        super().setUp()
        self.curr_env_var = os.environ.copy()

    def tearDown(self):
        super().tearDown()
        os.environ = self.curr_env_var

    def test_get_cluster_spec_from_env_vars(self):
        os.environ["CLUSTER_SPEC"] = ""
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "8080"
        os.environ["RANK"] = "0"
        os.environ["NNODES"] = "2"
        cluster_info = cluster_spec.get_cluster_spec()
        self.assertEqual(cluster_info.primary_node_addr, "127.0.0.1")
        self.assertEqual(cluster_info.primary_node_port, "8080")
        self.assertEqual(cluster_info.node_rank, 0)
        self.assertEqual(cluster_info.num_nodes, 2)

    def test_get_cluster_spec_from_cluster_spec(self):
        os.environ["CLUSTER_SPEC"] = """
        {
          "cluster": {
            "workerpool0": ["127.0.0.1:8080"],
            "workerpool1": ["127.0.0.2:8080", "127.0.0.3:8080"]
          },
          "task": {"type": "workerpool1", "index": 0}
        }
        """
        cluster_info = cluster_spec.get_cluster_spec()
        self.assertEqual(cluster_info.primary_node_addr, "127.0.0.1")
        self.assertEqual(cluster_info.primary_node_port, "8080")
        self.assertEqual(cluster_info.node_rank, 1)
        self.assertEqual(cluster_info.num_nodes, 3)


if __name__ == "__main__":
    unittest.main()
```
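For context, the tests above pin down roughly what `cluster_spec.get_cluster_spec` must do: prefer Vertex AI's `CLUSTER_SPEC` JSON when set, otherwise fall back to the torchrun-style `MASTER_ADDR`/`MASTER_PORT`/`RANK`/`NNODES` variables. The sketch below is reconstructed from those tests alone; `ClusterInfo` and the rank computation are assumptions, not the PR's actual implementation.

```python
import json
import os
from dataclasses import dataclass


@dataclass
class ClusterInfo:
    primary_node_addr: str
    primary_node_port: str
    node_rank: int
    num_nodes: int


def get_cluster_spec() -> ClusterInfo:
    raw = os.environ.get("CLUSTER_SPEC", "")
    if not raw:
        # Fallback: plain torchrun-style environment variables.
        return ClusterInfo(
            primary_node_addr=os.environ["MASTER_ADDR"],
            primary_node_port=os.environ["MASTER_PORT"],
            node_rank=int(os.environ["RANK"]),
            num_nodes=int(os.environ["NNODES"]),
        )
    spec = json.loads(raw)
    pools = spec["cluster"]
    # The first replica of workerpool0 acts as the primary node.
    addr, port = pools["workerpool0"][0].split(":")
    num_nodes = sum(len(replicas) for replicas in pools.values())
    # Global rank = this replica's index, offset by the sizes of earlier pools.
    task = spec["task"]
    pool_names = sorted(pools)
    offset = sum(len(pools[p]) for p in pool_names[: pool_names.index(task["type"])])
    return ClusterInfo(addr, port, offset + task["index"], num_nodes)
```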
### 4.1. Job Configuration File
Once the container is built, update the `job_config.json` to set up the training job.
File: `job_config.json`
```json
"zone": "<zone if using reservation>",
"bucket": "<bucket>",
"dataset_bucket": "github-repo/data/third-party/enwiki-latest-pages-articles",
"image_uri": "<docker image uri from artifact registry>",
```
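One detail worth noting from the config snippets in this PR: numeric values such as `gpus_per_node` appear as JSON strings, so the launcher has to cast them. A minimal sketch of reading the config (the values below are placeholders patterned on the snippets, not a complete `job_config.json`):

```python
import json

# Placeholder contents patterned on the fields shown in this PR's snippets.
RAW = """
{
  "zone": "us-east4-a",
  "bucket": "my-bucket",
  "dataset_bucket": "github-repo/data/third-party/enwiki-latest-pages-articles",
  "gpus_per_node": "8",
  "recipe_name": "llama3_1_8b_pretrain_a3mega"
}
"""

config = json.loads(RAW)
# gpus_per_node is a JSON string in the sample, so cast before arithmetic:
gpus_per_node = int(config["gpus_per_node"])
# Optional keys are best read with .get() so absence yields None, not KeyError:
reservation = config.get("reservation_name")

print(gpus_per_node, reservation)  # 8 None
```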
```yaml
ffn_hidden_size: 14336 # 8b: 14336 | 70b: 28672 | 405b: 53248
num_attention_heads: 32 # 8b: 32 | 70b: 64 | 405b: 128
num_query_groups: 8 # Number of query groups for group query attention. If None, normal attention is used. 8b: 8 | 70b: 8 | 405b: 16
init_method_std: 0.01 # Standard deviation of the zero mean normal distribution used for weight initialization. 8b: 0.01 | 70b: 0.008944 | 405b: 0.02
```
```dockerfile
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" | \
    tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
    apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \
    apt-get update -y && apt-get install google-cloud-sdk -y && \
    rm -rf /var/lib/apt/lists*
```
`apt-key` is deprecated. Download the key directly to the keyring location instead.

```dockerfile
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
    tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
    apt-get update -y && apt-get install google-cloud-sdk -y && \
    rm -rf /var/lib/apt/lists*
```
```dockerfile
RUN patch --verbose -l -d /opt/megatron-lm/megatron/core/datasets -p1 -i /workspace/vdt/patches/local_rank.patch; \
    git -C /workspace/gpu-recipes apply /workspace/vdt/patches/throughput_calc.patch; \
    git -C /opt/NeMo apply /workspace/vdt/patches/nemo2hf.patch; \
    git -C /opt/NeMo apply /workspace/vdt/patches/sigabort.patch;
```
Use `&&` instead of `;` to chain commands. This will cause the `RUN` step to fail immediately if any command returns a non-zero exit code.

```dockerfile
RUN patch --verbose -l -d /opt/megatron-lm/megatron/core/datasets -p1 -i /workspace/vdt/patches/local_rank.patch && \
    git -C /workspace/gpu-recipes apply /workspace/vdt/patches/throughput_calc.patch && \
    git -C /opt/NeMo apply /workspace/vdt/patches/nemo2hf.patch && \
    git -C /opt/NeMo apply /workspace/vdt/patches/sigabort.patch
```
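The distinction matters because `sh -c 'a; b'` reports only `b`'s exit status, so a failed patch earlier in the chain is silently swallowed. A quick way to observe both behaviors (illustrative, not part of the PR):

```python
import subprocess

# With ';', the second command runs even though the first failed, and the
# shell's overall exit status is 0 -- a Docker RUN step would "succeed".
semi = subprocess.run(["sh", "-c", "false; echo ran"],
                      capture_output=True, text=True)

# With '&&', the failure short-circuits: nothing after it runs, and the
# non-zero status would make the Docker build stop at the broken patch.
chained = subprocess.run(["sh", "-c", "false && echo ran"],
                         capture_output=True, text=True)

print(semi.returncode, repr(semi.stdout))      # 0 'ran\n'
print(chained.returncode, repr(chained.stdout))  # 1 ''
```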
Run this command to build the container and push it into the Google Artifact Registry.

```bash
cd "${REPO_ROOT}/a3mega/llama-3-8b-nemo-pretraining"
```
This path is incorrect with the new vertexai-samples repo structure.
### 4.1. Job Configuration File
Once the container is built, update the `job_config.json` to set up the training job.
There is no job_config.json file in the folder.
```json
"gpu_type": "NVIDIA_H100_MEGA_80GB",
"gpus_per_node": "8",
"recipe_name": "llama3_1_8b_pretrain_a3mega",
"job_prefix": "mchrestkha-spot-",
```
remove reference to 'mchrestkha' and replace with blank
4a23b17 to bcbbc8d
LGTM. Ran 3 successful jobs with Spot VMs in us-east4.