From d5874735eada4c4a0d91b0c8d1cd639cd0941292 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 12 Nov 2024 14:08:47 -0800
Subject: [PATCH] bedrock: register CoreModelId aliases and resolve provider
 model IDs through the model store

---
 .../providers/remote/inference/bedrock/bedrock.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 2f1378696..47abff689 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -7,6 +7,7 @@
 from typing import *  # noqa: F403
 
 from botocore.client import BaseClient
+from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
@@ -25,15 +26,18 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 model_aliases = [
     ModelAlias(
         provider_model_id="meta.llama3-1-8b-instruct-v1:0",
-        aliases=["Llama3.1-8B"],
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct,
     ),
     ModelAlias(
         provider_model_id="meta.llama3-1-70b-instruct-v1:0",
-        aliases=["Llama3.1-70B"],
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct,
     ),
     ModelAlias(
         provider_model_id="meta.llama3-1-405b-instruct-v1:0",
-        aliases=["Llama3.1-405B"],
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct,
     ),
 ]
 
@@ -308,8 +312,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -414,7 +419,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         pass
 
     def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
-        bedrock_model = self.map_to_provider_model(request.model)
+        bedrock_model = request.model
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             request.sampling_params
        )
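
Reviewer note: the patch moves model-name resolution out of the request path. Instead of mapping the user-facing name with map_to_provider_model() inside _get_params_for_chat_completion(), each ModelAlias now carries a CoreModelId, the alias is resolved once at registration time, and chat_completion() reads the already-resolved provider_resource_id from the model store. The following self-contained sketch illustrates that flow; ModelAlias, Model, and ModelStore here are simplified stand-ins for illustration, not llama_stack's actual implementations (which are async and richer):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ModelAlias:
    provider_model_id: str  # the ID Bedrock expects, e.g. "meta.llama3-1-8b-instruct-v1:0"
    aliases: List[str]      # user-facing names accepted at registration
    llama_model: str        # canonical core model ID (CoreModelId in the real code)

@dataclass
class Model:
    identifier: str
    provider_resource_id: str

MODEL_ALIASES = [
    ModelAlias(
        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
        aliases=["Llama3.1-8B-Instruct"],
        llama_model="llama3_1_8b_instruct",
    ),
]

class ModelStore:
    """Resolves a user-facing name to a provider model ID once, at registration."""

    def __init__(self) -> None:
        self._models: Dict[str, Model] = {}

    def register(self, identifier: str, aliases: List[ModelAlias]) -> Model:
        for alias in aliases:
            if identifier in alias.aliases or identifier == alias.provider_model_id:
                model = Model(identifier, alias.provider_model_id)
                self._models[identifier] = model
                return model
        raise ValueError(f"unknown model: {identifier}")

    def get_model(self, identifier: str) -> Model:
        return self._models[identifier]

# Usage mirroring the patched chat_completion(): look the model up, then build
# the request with the resolved provider ID, so the params builder can use
# request.model directly -- no map_to_provider_model() call needed.
store = ModelStore()
store.register("Llama3.1-8B-Instruct", MODEL_ALIASES)
model = store.get_model("Llama3.1-8B-Instruct")
assert model.provider_resource_id == "meta.llama3-1-8b-instruct-v1:0"

The practical effect: the Bedrock adapter no longer carries its own string-mapping logic, and by the time _get_params_for_chat_completion() runs, request.model is already a valid Bedrock model ID.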