mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)
bedrock
This commit is contained in:
parent 92ee627e89
commit d5874735ea
1 changed file with 10 additions and 5 deletions
@@ -7,6 +7,7 @@
 from typing import *  # noqa: F403
 
 from botocore.client import BaseClient
+from llama_models.datatypes import CoreModelId
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -25,15 +26,18 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 model_aliases = [
     ModelAlias(
         provider_model_id="meta.llama3-1-8b-instruct-v1:0",
-        aliases=["Llama3.1-8B"],
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct,
     ),
     ModelAlias(
         provider_model_id="meta.llama3-1-70b-instruct-v1:0",
-        aliases=["Llama3.1-70B"],
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct,
     ),
     ModelAlias(
         provider_model_id="meta.llama3-1-405b-instruct-v1:0",
-        aliases=["Llama3.1-405B"],
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct,
     ),
 ]
 
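The hunk above renames each user-facing alias to its -Instruct form and pins it to an explicit CoreModelId. As a rough sketch of what such an alias table buys the adapter, resolution can be a straight scan; the ModelAlias and CoreModelId definitions below are simplified stand-ins for illustration, not the real llama_stack / llama_models classes:

from dataclasses import dataclass
from enum import Enum
from typing import List


class CoreModelId(str, Enum):
    # Hypothetical subset; the real enum comes from llama_models.datatypes.
    llama3_1_8b_instruct = "Llama3.1-8B-Instruct"


@dataclass
class ModelAlias:
    provider_model_id: str  # Bedrock-side model id
    aliases: List[str]      # user-facing names accepted for this model
    llama_model: CoreModelId


model_aliases = [
    ModelAlias(
        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
        aliases=["Llama3.1-8B-Instruct"],
        llama_model=CoreModelId.llama3_1_8b_instruct,
    ),
]


def resolve_provider_model_id(name: str) -> str:
    # Accept either a registered alias or the canonical core-model name.
    for entry in model_aliases:
        if name in entry.aliases or name == entry.llama_model.value:
            return entry.provider_model_id
    raise ValueError(f"unknown model: {name}")


assert resolve_provider_model_id("Llama3.1-8B-Instruct") == "meta.llama3-1-8b-instruct-v1:0"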
@@ -308,8 +312,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
     ]:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
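With this hunk, chat completion first resolves model_id through the adapter's model store and puts the provider-side identifier on the request. A minimal sketch of that resolution step, assuming stand-in Model and store types rather than the real llama_stack ones:

import asyncio
from dataclasses import dataclass
from typing import Dict


@dataclass
class Model:
    identifier: str            # user-facing name, e.g. "Llama3.1-8B-Instruct"
    provider_resource_id: str  # Bedrock id, e.g. "meta.llama3-1-8b-instruct-v1:0"


class FakeModelStore:
    # Stand-in for the adapter's model_store; assumed interface.
    def __init__(self, models: Dict[str, Model]):
        self._models = models

    async def get_model(self, model_id: str) -> Model:
        return self._models[model_id]


async def demo() -> None:
    store = FakeModelStore({
        "Llama3.1-8B-Instruct": Model(
            identifier="Llama3.1-8B-Instruct",
            provider_resource_id="meta.llama3-1-8b-instruct-v1:0",
        ),
    })
    model = await store.get_model("Llama3.1-8B-Instruct")
    # This is the value the request now carries instead of the raw model_id.
    print(model.provider_resource_id)


if __name__ == "__main__":
    asyncio.run(demo())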
@@ -414,7 +419,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             pass
 
     def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
-        bedrock_model = self.map_to_provider_model(request.model)
+        bedrock_model = request.model
         inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
             request.sampling_params
         )
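Because request.model now arrives already set to the Bedrock id (see the previous hunk), the extra map_to_provider_model() translation is dropped; running it again would treat a provider id as if it were a user-facing alias. The unchanged context also builds an inference config from the sampling params; below is a hedged sketch of that kind of mapping, where the SamplingParams fields and Bedrock config keys are assumptions for illustration, not the adapter's exact translation:

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SamplingParams:
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None


def get_bedrock_inference_config(params: SamplingParams) -> Dict:
    # Copy only the fields the caller actually set, renaming to the
    # (assumed) Bedrock parameter names.
    config: Dict = {}
    if params.temperature is not None:
        config["temperature"] = params.temperature
    if params.top_p is not None:
        config["top_p"] = params.top_p
    if params.max_tokens is not None:
        config["max_gen_len"] = params.max_tokens
    return config


print(get_bedrock_inference_config(SamplingParams(temperature=0.7, max_tokens=512)))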