
add batch prediction #4127

Draft · wants to merge 1 commit into base: main
@@ -68,6 +68,7 @@
"\n",
"- Deploy Gemma 2 with Hex-LLM on TPU\n",
"- Deploy Gemma with [TGI](https://github.com/huggingface/text-generation-inference) on GPU\n",
"- Perform batch prediction\n",
"\n",
"### File a bug\n",
"\n",
@@ -119,11 +120,15 @@
"\n",
"# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).\n",
"\n",
"# @markdown 2. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.\n",
"# @markdown 2. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. \"us\") is not considered a match for a single region covered by the multi-region range (eg. \"us-central1\"). If not set, a unique GCS bucket will be created instead.\n",
"\n",
"BUCKET_URI = \"gs://\" # @param {type:\"string\"}\n",
"\n",
"# @markdown 3. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.\n",
"\n",
"REGION = \"\" # @param {type:\"string\"}\n",
"\n",
"# @markdown 3. If you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus). You can request for quota following the instructions at [\"Request a higher quota\"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).\n",
"# @markdown 4. If you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus). You can request for quota following the instructions at [\"Request a higher quota\"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).\n",
"\n",
"# @markdown > | Machine Type | Accelerator Type | Recommended Regions |\n",
"# @markdown | ----------- | ----------- | ----------- |\n",
@@ -135,8 +140,10 @@
"# Upgrade Vertex AI SDK.\n",
"! pip3 install --upgrade --quiet 'google-cloud-aiplatform==1.97.0'\n",
"\n",
"import datetime\n",
"import importlib\n",
"import os\n",
"import uuid\n",
"from typing import Tuple\n",
"\n",
"from google.cloud import aiplatform\n",
@@ -158,13 +165,60 @@
"\n",
"# Get the default region for launching jobs.\n",
"if not REGION:\n",
" if not os.environ.get(\"GOOGLE_CLOUD_REGION\"):\n",
" raise ValueError(\n",
" \"REGION must be set. See\"\n",
" \" https://cloud.google.com/vertex-ai/docs/general/locations for\"\n",
" \" available cloud locations.\"\n",
" )\n",
" REGION = os.environ[\"GOOGLE_CLOUD_REGION\"]\n",
"\n",
"# Enable the Vertex AI API and Compute Engine API, if not already.\n",
"print(\"Enabling Vertex AI API and Compute Engine API.\")\n",
"! gcloud services enable aiplatform.googleapis.com compute.googleapis.com\n",
"\n",
"# Cloud Storage bucket for storing the experiment artifacts.\n",
"# A unique GCS bucket will be created for the purpose of this notebook. If you\n",
"# prefer using your own GCS bucket, change the value yourself below.\n",
"now = datetime.datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
"BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
"\n",
"if BUCKET_URI is None or BUCKET_URI.strip() == \"\" or BUCKET_URI == \"gs://\":\n",
" BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}\"\n",
" BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n",
" ! gsutil mb -l {REGION} {BUCKET_URI}\n",
"else:\n",
" assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n",
" shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep \"Location constraint:\" | sed \"s/Location constraint://\"\n",
" bucket_region = shell_output[0].strip().lower()\n",
" if bucket_region != REGION:\n",
" raise ValueError(\n",
" \"Bucket region %s is different from notebook region %s\"\n",
" % (bucket_region, REGION)\n",
" )\n",
"print(f\"Using this GCS Bucket: {BUCKET_URI}\")\n",
"\n",
"STAGING_BUCKET = os.path.join(BUCKET_URI, \"temporal\")\n",
"MODEL_BUCKET = os.path.join(BUCKET_URI, \"gemma2\")\n",
"\n",
"\n",
"# Initialize Vertex AI API.\n",
"print(\"Initializing Vertex AI API.\")\n",
"aiplatform.init(project=PROJECT_ID, location=REGION)\n",
"aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)\n",
"\n",
"# Gets the default SERVICE_ACCOUNT.\n",
"shell_output = ! gcloud projects describe $PROJECT_ID\n",
"project_number = shell_output[-1].split(\":\")[1].strip().replace(\"'\", \"\")\n",
"SERVICE_ACCOUNT = f\"{project_number}-compute@developer.gserviceaccount.com\"\n",
"print(\"Using this default Service Account:\", SERVICE_ACCOUNT)\n",
"\n",
"\n",
"# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket\n",
"! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME\n",
"\n",
"! gcloud config set project $PROJECT_ID\n",
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/storage.admin\"\n",
"! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/aiplatform.user\"\n",
"\n",
"import vertexai\n",
"\n",
@@ -731,6 +785,70 @@
" print(prediction)"
]
},
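The next cell kicks off the batch prediction job. Its `input_dataset` points at a public sample JSONL file; to batch-predict over your own prompts, the input must already live in Cloud Storage as JSONL, one request per line. The exact request schema depends on the model's serving container, so the `"prompt"` field below is an assumption based on the sample file's name, not a documented contract; a minimal staging sketch:

```python
# Hypothetical sketch: write a small JSONL file of prompts and upload it to
# the notebook's bucket. The "prompt" field name is an assumption; verify
# the schema your serving container expects before relying on it.
import json

from google.cloud import storage

prompts = [
    "What is Vertex AI?",
    "Summarize the history of TPUs in one sentence.",
]
local_path = "/tmp/batch_input.jsonl"
with open(local_path, "w") as f:
    for p in prompts:
        f.write(json.dumps({"prompt": p}) + "\n")

storage_client = storage.Client(project=PROJECT_ID)
bucket_name = BUCKET_URI.removeprefix("gs://").split("/")[0]
blob = storage_client.bucket(bucket_name).blob("batch/input/batch_input.jsonl")
blob.upload_from_filename(local_path)
input_dataset = f"gs://{bucket_name}/batch/input/batch_input.jsonl"
```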
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "4kdU8eEvJmr3"
},
"outputs": [],
"source": [
"# @title Batch Predict\n",
"\n",
"# @markdown Batch prediction refers to the process of generating predictions for a large number of data points simultaneously using a machine learning model, rather than making predictions one at a time.\n",
"# @markdown This approach is suitable when real-time responses are not required and processing a large volume of data efficiently is the priority.\n",
"# @markdown For more information, see [Batch prediction overview](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini).\n",
"\n",
"import time\n",
"\n",
"from vertexai import model_garden\n",
"\n",
"MODEL_ID = \"gemma-2-27b-it\" # @param [\"gemma-2-2b-it\",\"gemma-2-9b-it\", \"gemma-2-27b-it\"] {allow-input: true, isTemplate: true}\n",
"\n",
"input_dataset = \"gs://cloud-samples-data/vertex-ai/batch-prediction/sample_prompt_batch_predictions.jsonl\" # @param {type:\"string\"}\n",
"output_uri_prefix = os.path.join(MODEL_BUCKET, \"output\")\n",
"\n",
"model_id = f\"google/gemma2@{MODEL_ID}\"\n",
"\n",
"accelerator_type = \"NVIDIA_L4\" # @param [\"NVIDIA_L4\"] {isTemplate: true}\n",
"\n",
"if \"2b\" in MODEL_ID:\n",
" accelerator_type == \"NVIDIA_L4\"\n",
" machine_type = \"g2-standard-12\"\n",
" accelerator_count = 1\n",
"elif \"9b\" in MODEL_ID:\n",
" accelerator_type == \"NVIDIA_L4\"\n",
" machine_type = \"g2-standard-24\"\n",
" accelerator_count = 2\n",
"elif \"27b\" in MODEL_ID:\n",
" accelerator_type == \"NVIDIA_L4\"\n",
" machine_type = \"g2-standard-48\"\n",
" accelerator_count = 4\n",
"else:\n",
" raise ValueError(\"Recommended machine settings not found for model: %s\" % MODEL_ID)\n",
"\n",
"\n",
"model = model_garden.OpenModel(model_id)\n",
"job = model.batch_predict(\n",
" input_dataset=input_dataset,\n",
" output_uri_prefix=output_uri_prefix,\n",
" job_display_name=common_util.get_job_name_with_datetime(prefix=model_id),\n",
" machine_type=machine_type,\n",
" accelerator_type=accelerator_type,\n",
" accelerator_count=accelerator_count,\n",
" starting_replica_count=1,\n",
")\n",
"\n",
"\n",
"while not job.has_ended:\n",
" print(\"Job is running...\")\n",
" time.sleep(60)\n",
" job.refresh()\n",
"\n",
"print(\"The batch predictionjob has finished with status: \", job.state)"
]
},
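Once the job finishes successfully, the predictions are written under `output_uri_prefix`. The output file layout and naming vary by job, so the sketch below simply lists everything under the prefix and parses any JSONL it finds; treat it as a starting point rather than the notebook's canonical output handling:

```python
# Hypothetical sketch: list and parse batch prediction outputs under
# output_uri_prefix. This reads every .jsonl blob found under the prefix.
import json

from google.cloud import storage

storage_client = storage.Client(project=PROJECT_ID)
bucket_name = output_uri_prefix.removeprefix("gs://").split("/")[0]
prefix = output_uri_prefix.removeprefix(f"gs://{bucket_name}/")

for blob in storage_client.list_blobs(bucket_name, prefix=prefix):
    if not blob.name.endswith(".jsonl"):
        continue
    for line in blob.download_as_text().splitlines():
        print(json.loads(line))
```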
{
"cell_type": "markdown",
"metadata": {
@@ -760,7 +878,11 @@
"\n",
"# Delete models.\n",
"for model in models.values():\n",
" model.delete()"
" model.delete()\n",
"\n",
"delete_bucket = False # @param {type:\"boolean\"}\n",
"if delete_bucket:\n",
" ! gsutil -m rm -r $BUCKET_NAME"
]
}
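One caveat on the cleanup cell above: `gsutil -m rm -r $BUCKET_NAME` removes the bucket even when it was user-supplied rather than auto-created by this notebook. A hypothetical safeguard, keyed off the naming pattern the auto-created bucket uses, would be:

```python
# Hypothetical safeguard: only delete buckets matching the auto-created
# naming pattern (gs://{PROJECT_ID}-tmp-...), never a user-supplied one.
if delete_bucket:
    if BUCKET_NAME.startswith(f"gs://{PROJECT_ID}-tmp-"):
        ! gsutil -m rm -r $BUCKET_NAME
    else:
        print(f"Skipping deletion of user-supplied bucket: {BUCKET_NAME}")
```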
],