Inference to use provider resource id to register and validate (#428)

This PR changes the way model id gets translated to the final model name
that gets passed through the provider.
Major changes include:
1) Providers are responsible for registering an object and as part of
the registration returning the object with the correct provider specific
name of the model provider_resource_id
2) To help with the common look ups different names a new ModelLookup
class is created.



Tested all inference providers including together, fireworks, vllm,
ollama, meta reference and bedrock
This commit is contained in:
Dinesh Yeduguru 2024-11-12 20:02:00 -08:00 committed by GitHub
parent e51107e019
commit fdff24e77a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 460 additions and 290 deletions

View file

@ -7,11 +7,15 @@
from typing import * # noqa: F403
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.apis.inference import * # noqa: F403
@ -19,19 +23,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
BEDROCK_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}
model_aliases = [
build_model_alias(
"meta.llama3-1-8b-instruct-v1:0",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta.llama3-1-70b-instruct-v1:0",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
]
# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
)
ModelRegistryHelper.__init__(self, model_aliases)
self._config = config
self._client = create_bedrock_client(config)
@ -49,7 +60,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -286,7 +297,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -298,8 +309,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -404,7 +416,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
pass
def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
bedrock_model = self.map_to_provider_model(request.model)
bedrock_model = request.model
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
request.sampling_params
)
@ -433,7 +445,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()