From c2f7905fa4f9515ce87573add6002a7cc5c4203f Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 16 Dec 2024 14:22:34 -0800
Subject: [PATCH] Fix bedrock inference impl

---
 .../self_hosted_distro/bedrock.md             |  7 +++++++
 .../distribution/tests/library_client_test.py |  3 ++-
 .../remote/inference/bedrock/bedrock.py       |  8 ++++----
 llama_stack/templates/bedrock/bedrock.py      | 20 +++++++++++++++++--
 llama_stack/templates/bedrock/run.yaml        | 17 +++++++++++++++-
 5 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/docs/source/distributions/self_hosted_distro/bedrock.md b/docs/source/distributions/self_hosted_distro/bedrock.md
index ae03c89da..7dab23655 100644
--- a/docs/source/distributions/self_hosted_distro/bedrock.md
+++ b/docs/source/distributions/self_hosted_distro/bedrock.md
@@ -28,6 +28,13 @@ The following environment variables can be configured:
 
 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 
+### Models
+
+The following models are available by default:
+
+- `meta-llama/Llama-3.1-8B-Instruct (meta.llama3-1-8b-instruct-v1:0)`
+- `meta-llama/Llama-3.1-70B-Instruct (meta.llama3-1-70b-instruct-v1:0)`
+- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta.llama3-1-405b-instruct-v1:0)`
 
 ### Prerequisite: API Keys
 
diff --git a/llama_stack/distribution/tests/library_client_test.py b/llama_stack/distribution/tests/library_client_test.py
index 955640c2b..a919ab223 100644
--- a/llama_stack/distribution/tests/library_client_test.py
+++ b/llama_stack/distribution/tests/library_client_test.py
@@ -29,7 +29,8 @@ def main(config_path: str):
         print("No models found, skipping chat completion test")
         return
 
-    model_id = models[0].identifier
+    model_id = next(m.identifier for m in models if "8b" in m.identifier.lower())
+    print(f"Using model: {model_id}")
     response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 96cbcaa67..d5565dd62 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -6,7 +6,7 @@
 from typing import *  # noqa: F403
 
 import json
-
+import uuid
 from botocore.client import BaseClient
 
 from llama_models.datatypes import CoreModelId
@@ -26,7 +26,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 
 
-model_aliases = [
+MODEL_ALIASES = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
         CoreModelId.llama3_1_8b_instruct.value,
@@ -45,7 +45,7 @@ model_aliases = [
 # NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_aliases)
+        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self._config = config
 
         self._client = create_bedrock_client(config)
@@ -146,7 +146,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 [
                     {
                         "toolResult": {
-                            "toolUseId": message.call_id,
+                            "toolUseId": message.call_id or str(uuid.uuid4()),
                             "content": [
                                 {"text": content} for content in content_list
                             ],
diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py
index c52b56612..8911d159d 100644
--- a/llama_stack/templates/bedrock/bedrock.py
+++ b/llama_stack/templates/bedrock/bedrock.py
@@ -6,11 +6,13 @@
 
 from pathlib import Path
 
+from llama_models.sku_list import all_registered_models
 from llama_stack.distribution.datatypes import Provider
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
-
+from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
+from llama_stack.apis.models import ModelInput
 
 
 def get_distribution_template() -> DistributionTemplate:
     providers = {
@@ -30,6 +32,19 @@ def get_distribution_template() -> DistributionTemplate:
         config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
     )
 
+    core_model_to_hf_repo = {
+        m.descriptor(): m.huggingface_repo for m in all_registered_models()
+    }
+
+    default_models = [
+        ModelInput(
+            model_id=core_model_to_hf_repo[m.llama_model],
+            provider_model_id=m.provider_model_id,
+            provider_id="bedrock",
+        )
+        for m in MODEL_ALIASES
+    ]
+
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
@@ -37,12 +52,13 @@ def get_distribution_template() -> DistributionTemplate:
         docker_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
-        default_models=[],
+        default_models=default_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "memory": [memory_provider],
                 },
+                default_models=default_models,
             ),
         },
         run_config_env_vars={
diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml
index 47885b536..9aa5ca914 100644
--- a/llama_stack/templates/bedrock/run.yaml
+++ b/llama_stack/templates/bedrock/run.yaml
@@ -69,7 +69,22 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
-models: []
+models:
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-8b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-70B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-70b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-405b-instruct-v1:0
+  model_type: llm
 shields: []
 memory_banks: []
 datasets: []
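
Note: the standalone sketch below illustrates the default-model wiring this patch adds
in llama_stack/templates/bedrock/bedrock.py -- the core_model_to_hf_repo lookup plus the
default_models comprehension over MODEL_ALIASES. It is a minimal sketch, not the real
implementation: ModelAlias and ModelInput here are simplified stand-in dataclasses, and
the descriptor strings in the map are illustrative placeholders rather than the actual
values produced by all_registered_models().

    from dataclasses import dataclass
    from typing import Dict, List


    # Simplified stand-ins for the real llama_stack / llama_models types
    # (illustrative only; actual types carry more fields).
    @dataclass
    class ModelAlias:
        provider_model_id: str  # Bedrock model ID sent to the AWS API
        llama_model: str        # core model descriptor used as the lookup key


    @dataclass
    class ModelInput:
        model_id: str           # HF-style alias exposed to clients
        provider_model_id: str  # what the adapter actually resolves it to
        provider_id: str


    # Mirrors the shape of MODEL_ALIASES in the bedrock provider module.
    MODEL_ALIASES: List[ModelAlias] = [
        ModelAlias("meta.llama3-1-8b-instruct-v1:0", "Llama3.1-8B-Instruct"),
        ModelAlias("meta.llama3-1-70b-instruct-v1:0", "Llama3.1-70B-Instruct"),
        ModelAlias("meta.llama3-1-405b-instruct-v1:0", "Llama3.1-405B-Instruct"),
    ]

    # Stand-in for the descriptor -> HF repo map the template builds from
    # all_registered_models(); keys and values here are assumed, not exact.
    core_model_to_hf_repo: Dict[str, str] = {
        "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
        "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
        "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
    }

    # Same shape as the comprehension added in get_distribution_template():
    # each alias becomes a ModelInput pairing the client-facing alias with
    # the Bedrock-facing provider model ID.
    default_models = [
        ModelInput(
            model_id=core_model_to_hf_repo[m.llama_model],
            provider_model_id=m.provider_model_id,
            provider_id="bedrock",
        )
        for m in MODEL_ALIASES
    ]

    if __name__ == "__main__":
        for m in default_models:
            print(f"{m.model_id} -> {m.provider_model_id}")

Registering both identifiers this way is what lets a client request the HF-style
model_id while ModelRegistryHelper resolves it to the provider_model_id that Bedrock
expects, matching the models: entries the patch emits into run.yaml.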