working fireworks and together

Dinesh Yeduguru 2024-11-12 13:07:35 -08:00
parent 25d8ab0e14
commit 8de4cee373
8 changed files with 205 additions and 86 deletions

@@ -11,7 +11,10 @@ from botocore.client import BaseClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.apis.inference import *  # noqa: F403
@@ -19,19 +22,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client

-BEDROCK_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
-    "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
-    "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
+        aliases=["Llama3.1-8B"],
+    ),
+    ModelAlias(
+        provider_model_id="meta.llama3-1-70b-instruct-v1:0",
+        aliases=["Llama3.1-70B"],
+    ),
+    ModelAlias(
+        provider_model_id="meta.llama3-1-405b-instruct-v1:0",
+        aliases=["Llama3.1-405B"],
+    ),
+]


 # NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self._config = config

         self._client = create_bedrock_client(config)
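
For context, this hunk replaces the static BEDROCK_SUPPORTED_MODELS dict with a list of ModelAlias entries passed to ModelRegistryHelper, so alias-to-provider-id resolution is handled by the shared helper rather than per-adapter (presumably the same pattern the Fireworks and Together adapters in this commit adopt). Below is a minimal sketch of how such an alias registry could work; ModelAlias and ModelRegistryHelper here are simplified stand-ins, and get_provider_model_id is a hypothetical method name, not necessarily the actual API in llama_stack.providers.utils.inference.model_registry.

# Illustrative sketch only; not the library's actual implementation.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ModelAlias:
    # Identifier the provider (here, Bedrock) expects, e.g. "meta.llama3-1-8b-instruct-v1:0".
    provider_model_id: str
    # Friendly names callers may use instead of the provider id.
    aliases: List[str] = field(default_factory=list)


class ModelRegistryHelper:
    def __init__(self, model_aliases: List[ModelAlias]) -> None:
        # Map every alias (and the provider id itself) back to the provider id.
        self._lookup: Dict[str, str] = {}
        for alias in model_aliases:
            self._lookup[alias.provider_model_id] = alias.provider_model_id
            for name in alias.aliases:
                self._lookup[name] = alias.provider_model_id

    def get_provider_model_id(self, identifier: str) -> str:
        # Hypothetical resolution method for this sketch.
        try:
            return self._lookup[identifier]
        except KeyError:
            raise ValueError(f"Unknown model: {identifier!r}") from None


# Usage, mirroring the adapter's __init__ above:
registry = ModelRegistryHelper(
    [ModelAlias(provider_model_id="meta.llama3-1-8b-instruct-v1:0", aliases=["Llama3.1-8B"])]
)
print(registry.get_provider_model_id("Llama3.1-8B"))  # -> meta.llama3-1-8b-instruct-v1:0

The practical upside of this shape is that adding a new provider only requires declaring its ModelAlias list; the lookup logic stays in one place.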