Fix bedrock inference impl

Author: Ashwin Bharambe, 2024-12-16 14:22:34 -08:00
Parent: eb37fba9da
Commit: c2f7905fa4
5 changed files with 47 additions and 8 deletions

@@ -28,6 +28,13 @@ The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+### Models
+The following models are available by default:
+- `meta-llama/Llama-3.1-8B-Instruct (meta.llama3-1-8b-instruct-v1:0)`
+- `meta-llama/Llama-3.1-70B-Instruct (meta.llama3-1-70b-instruct-v1:0)`
+- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta.llama3-1-405b-instruct-v1:0)`
### Prerequisite: API Keys
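
A quick way to sanity-check the documented defaults, as a sketch: this assumes the `llama_stack_client` package and a server already running on `LLAMASTACK_PORT` (the exact client API may differ between versions).

```python
import os

from llama_stack_client import LlamaStackClient

# Default port mirrors the documented LLAMASTACK_PORT default of 5001.
port = os.environ.get("LLAMASTACK_PORT", "5001")
client = LlamaStackClient(base_url=f"http://localhost:{port}")

# Each identifier should match one of the default models listed above.
for model in client.models.list():
    print(model.identifier)
```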

@@ -29,7 +29,8 @@ def main(config_path: str):
print("No models found, skipping chat completion test")
return
-model_id = models[0].identifier
+model_id = next(m.identifier for m in models if "8b" in m.identifier.lower())
+print(f"Using model: {model_id}")
response = client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
model_id=model_id,
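
One caveat with the new selection: `next()` without a default raises `StopIteration` when no registered identifier contains "8b". A more defensive variant (a sketch, not part of this commit) would fall back to the first model:

```python
# Prefer an 8B model, but fall back to the first registered model
# instead of letting next() raise StopIteration.
model_id = next(
    (m.identifier for m in models if "8b" in m.identifier.lower()),
    models[0].identifier,
)
```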

@@ -6,7 +6,7 @@
from typing import * # noqa: F403
import json
+import uuid
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
@@ -26,7 +26,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
-model_aliases = [
+MODEL_ALIASES = [
build_model_alias(
"meta.llama3-1-8b-instruct-v1:0",
CoreModelId.llama3_1_8b_instruct.value,
@@ -45,7 +45,7 @@ model_aliases = [
# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
-ModelRegistryHelper.__init__(self, model_aliases)
+ModelRegistryHelper.__init__(self, MODEL_ALIASES)
self._config = config
self._client = create_bedrock_client(config)
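
For context, the renamed `MODEL_ALIASES` table is what lets callers address a model either by its Hugging Face-style ID or by the native Bedrock ID. A simplified sketch of the resolution it enables (illustrative only; `ModelRegistryHelper`'s actual internals differ):

```python
# Distilled from the MODEL_ALIASES entries above; the real helper builds
# an equivalent lookup from build_model_alias() records.
BEDROCK_IDS = {
    "meta-llama/Llama-3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
    "meta-llama/Llama-3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
    "meta-llama/Llama-3.1-405B-Instruct-FP8": "meta.llama3-1-405b-instruct-v1:0",
}

def resolve_bedrock_model(model_id: str) -> str:
    # Accept either side of the mapping.
    if model_id in BEDROCK_IDS:
        return BEDROCK_IDS[model_id]
    if model_id in BEDROCK_IDS.values():
        return model_id
    raise ValueError(f"unknown model: {model_id}")
```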
@@ -146,7 +146,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
[
{
"toolResult": {
"toolUseId": message.call_id,
"toolUseId": message.call_id or str(uuid.uuid4()),
"content": [
{"text": content} for content in content_list
],
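
The new `or str(uuid.uuid4())` fallback (paired with the added `import uuid`) covers tool messages that arrive without a `call_id`: rather than sending Bedrock an empty `toolUseId` on the `toolResult` block, a fresh ID is minted. Distilled into a standalone helper (hypothetical name, for illustration):

```python
import uuid
from typing import Optional

def tool_use_id(call_id: Optional[str]) -> str:
    # Reuse the caller's ID when present; otherwise synthesize one so the
    # toolResult block always carries a non-empty toolUseId.
    return call_id or str(uuid.uuid4())
```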

@@ -6,11 +6,13 @@
from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import Provider
from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES
from llama_stack.apis.models import ModelInput
def get_distribution_template() -> DistributionTemplate:
providers = {
@@ -30,6 +32,19 @@ def get_distribution_template() -> DistributionTemplate:
config=FaissImplConfig.sample_run_config(f"distributions/{name}"),
)
+core_model_to_hf_repo = {
+    m.descriptor(): m.huggingface_repo for m in all_registered_models()
+}
+default_models = [
+    ModelInput(
+        model_id=core_model_to_hf_repo[m.llama_model],
+        provider_model_id=m.provider_model_id,
+        provider_id="bedrock",
+    )
+    for m in MODEL_ALIASES
+]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
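
Hand-expanding that comprehension for the three Bedrock aliases shows what it produces (descriptors resolved via `llama_models`' `sku_list`; literal form shown for illustration):

```python
# Equivalent literal value of default_models for the three aliases; this
# is exactly what ends up serialized into the generated run.yaml below.
default_models = [
    ModelInput(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
        provider_id="bedrock",
    ),
    # ...same shape for Llama-3.1-70B-Instruct and Llama-3.1-405B-Instruct-FP8.
]
```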
@@ -37,12 +52,13 @@ def get_distribution_template() -> DistributionTemplate:
docker_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
-default_models=[],
+default_models=default_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"memory": [memory_provider],
},
+default_models=default_models,
),
},
run_config_env_vars={

@@ -69,7 +69,22 @@ metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/registry.db
-models: []
+models:
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-8b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-70B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-70b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
+  provider_id: bedrock
+  provider_model_id: meta.llama3-1-405b-instruct-v1:0
+  model_type: llm
shields: []
memory_banks: []
datasets: []
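
With the models now registered by default, a minimal end-to-end check against the Bedrock distro looks like this. A sketch only: it assumes the server is running on the default port, AWS credentials for Bedrock are configured, and these client import paths are current.

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import UserMessage

client = LlamaStackClient(base_url="http://localhost:5001")

# model_id matches one of the entries registered in run.yaml above.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[UserMessage(role="user", content="What is the capital of France?")],
)
print(response)
```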