commit d5874735ea
parent 92ee627e89
Author: Dinesh Yeduguru
Date:   2024-11-12 14:08:47 -08:00


@@ -7,6 +7,7 @@
 from typing import *  # noqa: F403

 from botocore.client import BaseClient
+from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -25,15 +26,18 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client

 model_aliases = [
     ModelAlias(
         provider_model_id="meta.llama3-1-8b-instruct-v1:0",
-        aliases=["Llama3.1-8B"],
+        aliases=["Llama3.1-8B-Instruct"],
         llama_model=CoreModelId.llama3_1_8b_instruct,
     ),
     ModelAlias(
         provider_model_id="meta.llama3-1-70b-instruct-v1:0",
-        aliases=["Llama3.1-70B"],
+        aliases=["Llama3.1-70B-Instruct"],
         llama_model=CoreModelId.llama3_1_70b_instruct,
     ),
     ModelAlias(
         provider_model_id="meta.llama3-1-405b-instruct-v1:0",
-        aliases=["Llama3.1-405B"],
+        aliases=["Llama3.1-405B-Instruct"],
         llama_model=CoreModelId.llama3_1_405b_instruct,
     ),
 ]
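
For context, here is a minimal sketch of how an alias table like the one above can be resolved to a Bedrock provider model id. `ModelAlias` and `resolve_alias` below are simplified stand-ins for illustration, not the actual llama_stack classes (the real `ModelAlias` uses `CoreModelId` for `llama_model`):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelAlias:
    provider_model_id: str  # e.g. "meta.llama3-1-8b-instruct-v1:0"
    aliases: List[str]      # user-facing names, e.g. ["Llama3.1-8B-Instruct"]
    llama_model: str        # simplified; the real field is a CoreModelId


def resolve_alias(aliases: List[ModelAlias], name: str) -> Optional[str]:
    # Accept either the provider model id itself or any registered alias.
    for entry in aliases:
        if name == entry.provider_model_id or name in entry.aliases:
            return entry.provider_model_id
    return None


# After this change, the "-Instruct"-suffixed alias is the one that resolves:
# resolve_alias(model_aliases, "Llama3.1-8B-Instruct")
#   -> "meta.llama3-1-8b-instruct-v1:0"
```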
@@ -308,8 +312,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
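
The key change in this hunk: the adapter now looks the model up in the model store and puts the provider's resource id, rather than the user-facing identifier, into the request. A rough sketch of that flow, using a hypothetical in-memory store rather than the real llama_stack `model_store` (whose `Model` type and error handling may differ):

```python
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # user-facing alias, e.g. "Llama3.1-8B-Instruct"
    provider_resource_id: str  # Bedrock id, e.g. "meta.llama3-1-8b-instruct-v1:0"


class InMemoryModelStore:
    def __init__(self, models: list[Model]) -> None:
        self._models = {m.identifier: m for m in models}

    async def get_model(self, model_id: str) -> Model:
        # Raises KeyError for unregistered models; the real store
        # presumably surfaces a structured error instead.
        return self._models[model_id]


# model = await store.get_model("Llama3.1-8B-Instruct")
# request.model == model.provider_resource_id
#   -> downstream code receives the Bedrock id directly.
```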
@@ -414,7 +419,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         pass

     def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
-        bedrock_model = self.map_to_provider_model(request.model)
+        bedrock_model = request.model
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             request.sampling_params
         )
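
Because `request.model` now already carries the Bedrock model id, the per-request alias mapping (`map_to_provider_model`) is no longer needed here. A hedged sketch of what the parameter building reduces to; the dict keys mirror boto3's bedrock-runtime `converse` call, but the exact shape assembled by the real method is not shown in this diff:

```python
from typing import Any, Dict


def get_params_for_chat_completion(
    request_model: str, inference_config: Dict[str, Any]
) -> Dict[str, Any]:
    # request_model is already the provider id
    # (e.g. "meta.llama3-1-8b-instruct-v1:0"), so no
    # alias-to-provider mapping happens at this layer anymore.
    return {
        "modelId": request_model,
        "inferenceConfig": inference_config,
    }
```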