Fix bedrock inference impl

Ashwin Bharambe 2024-12-16 14:22:34 -08:00
parent eb37fba9da
commit c2f7905fa4
5 changed files with 47 additions and 8 deletions

View file

@@ -28,6 +28,13 @@ The following environment variables can be configured:
 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+### Models
+The following models are available by default:
+- `meta-llama/Llama-3.1-8B-Instruct (meta.llama3-1-8b-instruct-v1:0)`
+- `meta-llama/Llama-3.1-70B-Instruct (meta.llama3-1-70b-instruct-v1:0)`
+- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta.llama3-1-405b-instruct-v1:0)`
 ### Prerequisite: API Keys

View file

@@ -29,7 +29,8 @@ def main(config_path: str):
         print("No models found, skipping chat completion test")
         return
-    model_id = models[0].identifier
+    model_id = next(m.identifier for m in models if "8b" in m.identifier.lower())
+    print(f"Using model: {model_id}")
     response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
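
A note on the new selection logic: `next(m.identifier for m in models if "8b" in m.identifier.lower())` raises `StopIteration` when no registered identifier contains "8b". A minimal sketch of the same idea with an explicit fallback, assuming `models` is the list the test script already holds (the helper name `pick_model_id` is illustrative, not part of this commit):

```python
def pick_model_id(models, preferred_substring: str = "8b") -> str:
    # Prefer a model whose identifier mentions the substring (e.g. the 8B
    # Instruct model); otherwise fall back to the first registered model
    # instead of raising StopIteration.
    return next(
        (m.identifier for m in models if preferred_substring in m.identifier.lower()),
        models[0].identifier,
    )
```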

View file

@@ -6,7 +6,7 @@
 from typing import * # noqa: F403
 import json
+import uuid
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
@@ -26,7 +26,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 
-model_aliases = [
+MODEL_ALIASES = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
         CoreModelId.llama3_1_8b_instruct.value,
@@ -45,7 +45,7 @@ model_aliases = [
 # NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_aliases)
+        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self._config = config
         self._client = create_bedrock_client(config)
@@ -146,7 +146,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             [
                 {
                     "toolResult": {
-                        "toolUseId": message.call_id,
+                        "toolUseId": message.call_id or str(uuid.uuid4()),
                         "content": [
                             {"text": content} for content in content_list
                         ],
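
For context on the `toolUseId` change: Bedrock expects an ID on each `toolResult` block, and the adapter now substitutes a fresh UUID when the incoming message has no `call_id`. A minimal sketch of that fallback in isolation (the helper name `_tool_use_id` is illustrative; the commit inlines the expression):

```python
import uuid
from typing import Optional


def _tool_use_id(call_id: Optional[str]) -> str:
    # Use the tool call ID carried by the message when present; otherwise
    # generate a fresh UUID so the toolResult block is never sent without one.
    return call_id or str(uuid.uuid4())
```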

View file

@@ -6,11 +6,13 @@
 from pathlib import Path
+from llama_models.sku_list import all_registered_models
 from llama_stack.distribution.datatypes import Provider
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
+from llama_stack.apis.models import ModelInput
 
 def get_distribution_template() -> DistributionTemplate:
     providers = {
@@ -30,6 +32,19 @@ def get_distribution_template() -> DistributionTemplate:
         config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
     )
+    core_model_to_hf_repo = {
+        m.descriptor(): m.huggingface_repo for m in all_registered_models()
+    }
+    default_models = [
+        ModelInput(
+            model_id=core_model_to_hf_repo[m.llama_model],
+            provider_model_id=m.provider_model_id,
+            provider_id="bedrock",
+        )
+        for m in MODEL_ALIASES
+    ]
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
@@ -37,12 +52,13 @@ def get_distribution_template() -> DistributionTemplate:
         docker_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
-        default_models=[],
+        default_models=default_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "memory": [memory_provider],
                 },
+                default_models=default_models,
             ),
         },
         run_config_env_vars={
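
For readers skimming the template change: `default_models` is built by mapping each Bedrock alias to the Hugging Face repo name of its core Llama model via `all_registered_models()`. A plain-data sketch of the expected result, with values taken from the generated `run.yaml` shown further down (the variable name is illustrative):

```python
# What the three MODEL_ALIASES entries are expected to expand into; compare
# with the models: section of the generated run.yaml below.
expected_default_models = [
    {
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "provider_model_id": "meta.llama3-1-8b-instruct-v1:0",
        "provider_id": "bedrock",
    },
    {
        "model_id": "meta-llama/Llama-3.1-70B-Instruct",
        "provider_model_id": "meta.llama3-1-70b-instruct-v1:0",
        "provider_id": "bedrock",
    },
    {
        "model_id": "meta-llama/Llama-3.1-405B-Instruct-FP8",
        "provider_model_id": "meta.llama3-1-405b-instruct-v1:0",
        "provider_id": "bedrock",
    },
]
```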

View file

@@ -69,7 +69,22 @@ metadata_store:
   namespace: null
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
-models: []
+models:
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-8b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-70B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-70b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-405b-instruct-v1:0
+  model_type: llm
 shields: []
 memory_banks: []
 datasets: []
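
To sanity-check the regenerated `run.yaml` against a running Bedrock distribution, something along these lines should work, assuming the `llama-stack-client` Python package and the default port from the docs hunk above (this mirrors how the test script earlier in the commit enumerates models):

```python
from llama_stack_client import LlamaStackClient

# Assumed endpoint; LLAMASTACK_PORT defaults to 5001 per the docs change above.
client = LlamaStackClient(base_url="http://localhost:5001")

# Each identifier should match a model_id from the models: entries above,
# e.g. meta-llama/Llama-3.1-8B-Instruct.
for model in client.models.list():
    print(model.identifier)
```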