From 8de4cee3735e46507fc4eec87b17c2b0871ed446 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 12 Nov 2024 13:07:35 -0800
Subject: [PATCH] working fireworks and together

---
 llama_stack/distribution/routers/routers.py    |  9 +--
 .../distribution/routers/routing_tables.py     | 27 +++++--
 .../remote/inference/bedrock/bedrock.py        | 28 ++++---
 .../remote/inference/databricks/databricks.py  |  2 +-
 .../remote/inference/fireworks/fireworks.py    | 79 ++++++++++++++-----
 .../remote/inference/together/together.py      | 73 ++++++++++++-----
 .../utils/inference/model_registry.py          | 60 +++++++++-----
 .../utils/inference/prompt_adapter.py          | 13 +--
 8 files changed, 205 insertions(+), 86 deletions(-)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 7d4e43f60..5a62b6d64 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -105,9 +105,8 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
         params = dict(
-            model_id=model.provider_resource_id,
+            model_id=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -132,10 +131,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        model = await self.routing_table.get_model(model_id)
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
-            model_id=model.provider_resource_id,
+            model_id=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -152,9 +150,8 @@ class InferenceRouter(Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        model = await self.routing_table.get_model(model_id)
         return await self.routing_table.get_provider_impl(model_id).embeddings(
-            model_id=model.provider_resource_id,
+            model_id=model_id,
             contents=contents,
         )
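For orientation before the routing-table hunks: after this change the router no longer rewrites the caller's model_id, and resolution to the provider's own model id moves into the provider adapter. A minimal stand-alone sketch of that flow, using hypothetical ModelStore and ProviderAdapter stand-ins rather than the real llama_stack classes:

import asyncio


class ModelStore:
    """Hypothetical stand-in for the model_store attached to a provider adapter."""

    def __init__(self, models):
        # maps a user-facing identifier to the provider's own model id
        self._models = models

    async def get_model(self, model_id):
        return self._models[model_id]


class ProviderAdapter:
    """Hypothetical stand-in for an adapter such as FireworksInferenceAdapter."""

    def __init__(self, model_store):
        self.model_store = model_store

    async def chat_completion(self, model_id, messages):
        # The adapter, not the router, maps the user-facing id to the provider id.
        provider_resource_id = await self.model_store.get_model(model_id)
        return f"provider called with {provider_resource_id}"


async def main():
    adapter = ProviderAdapter(
        ModelStore({"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct"})
    )
    # The router simply forwards model_id to the adapter unchanged.
    print(await adapter.chat_completion("Llama3.1-8B-Instruct", messages=[]))


asyncio.run(main())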
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d6fb5d662..249d3a144 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
     return p.__provider_spec__.api
 
 
-async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+# TODO: this should return the registered object for all APIs
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
+
     api = get_impl_api(p)
 
     if obj.provider_id == "remote":
@@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
         obj.provider_id = ""
 
     if api == Api.inference:
-        await p.register_model(obj)
+        return await p.register_model(obj)
     elif api == Api.safety:
         await p.register_shield(obj)
     elif api == Api.memory:
@@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
         assert len(objects) == 1
         return objects[0]
 
-    async def register_object(self, obj: RoutableObjectWithProvider):
+    async def register_object(
+        self, obj: RoutableObjectWithProvider
+    ) -> RoutableObjectWithProvider:
         # Get existing objects from registry
         existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
 
@@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
                 print(
                     f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
                 )
-                return
+                return existing_obj
 
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
@@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):
 
         p = self.impls_by_provider_id[obj.provider_id]
 
-        await register_object_with_provider(obj, p)
-        await self.dist_registry.register(obj)
+        registered_obj = await register_object_with_provider(obj, p)
+        # TODO: This needs to be fixed for all APIs once they return the registered object
+        if obj.type == ResourceType.model.value:
+            await self.dist_registry.register(registered_obj)
+            return registered_obj
+
+        else:
+            await self.dist_registry.register(obj)
+            return obj
 
     async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
@@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             provider_id=provider_id,
             metadata=metadata,
         )
-        await self.register_object(model)
-        return model
+        registered_model = await self.register_object(model)
+        return registered_model
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 7900e096d..2f1378696 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -11,7 +11,10 @@ from botocore.client import BaseClient
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_models.datatypes import CoreModelId
+
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -19,19 +22,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 
 
-BEDROCK_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
-    "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
-    "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta.llama3-1-70b-instruct-v1:0",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta.llama3-1-405b-instruct-v1:0",
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct.value,
+    ),
+]
 
 
 # NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self._config = config
 
         self._client = create_bedrock_client(config)
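The alias tables introduced for Bedrock above (and for Fireworks and Together below) follow one pattern: each ModelAlias ties a provider-side model id to the user-facing names that should resolve to it, plus the Llama descriptor used later for prompt formatting. A small illustrative sketch of how such a table collapses into lookup maps; the single entry and descriptor string are examples, not a complete table:

from collections import namedtuple

ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])

# Illustrative entry only; real adapters define one entry per supported model.
aliases = [
    ModelAlias(
        provider_model_id="meta.llama3-1-8b-instruct-v1:0",
        aliases=["Llama3.1-8B-Instruct"],
        llama_model="Llama3.1-8B-Instruct",
    ),
]

alias_to_provider_id = {}
provider_id_to_llama_model = {}
for entry in aliases:
    for alias in entry.aliases:
        alias_to_provider_id[alias] = entry.provider_model_id
    # the provider id also maps to itself, so it is accepted as an identifier too
    alias_to_provider_id[entry.provider_model_id] = entry.provider_model_id
    provider_id_to_llama_model[entry.provider_model_id] = entry.llama_model

print(alias_to_provider_id["Llama3.1-8B-Instruct"])  # meta.llama3-1-8b-instruct-v1:0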
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index f12ecb7f5..8e1f7693a 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -37,7 +37,7 @@ DATABRICKS_SUPPORTED_MODELS = {
 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
         ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+            self, provider_to_common_model_aliases_map=DATABRICKS_SUPPORTED_MODELS
         )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 67bf1cf47..ce9639cbd 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,14 +7,17 @@
 from typing import AsyncGenerator
 
 from fireworks.client import Fireworks
+from llama_models.datatypes import CoreModelId
 
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -31,25 +34,61 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import FireworksImplConfig
 
-FIREWORKS_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
-    "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
-    "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
-    "Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
-    "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
-    "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
-    "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
-    "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
-}
+
+model_aliases = [
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p1-8b-instruct",
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p1-70b-instruct",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p1-405b-instruct",
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-1b-instruct",
+        aliases=["Llama3.2-1B-Instruct"],
+        llama_model=CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-3b-instruct",
+        aliases=["Llama3.2-3B-Instruct"],
+        llama_model=CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-11b-vision-instruct",
+        aliases=["Llama3.2-11B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-v3p2-90b-vision-instruct",
+        aliases=["Llama3.2-90B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-guard-3-8b",
+        aliases=["Llama-Guard-3-8B"],
+        llama_model=CoreModelId.llama_guard_3_8b.value,
+    ),
+    ModelAlias(
+        provider_model_id="fireworks/llama-guard-3-11b-vision",
+        aliases=["Llama-Guard-3-11B-Vision"],
+        llama_model=CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]
 
 
 class FireworksInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -81,8 +120,9 @@ class FireworksInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -148,8 +188,9 @@ class FireworksInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -207,7 +248,7 @@ class FireworksInferenceAdapter(
                 ]
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
@@ -221,7 +262,7 @@ class FireworksInferenceAdapter(
             input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
         return {
-            "model": self.map_to_provider_model(request.model),
+            "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 1b04ae556..75f93f64f 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -6,6 +6,8 @@
 
 from typing import AsyncGenerator
 
+from llama_models.datatypes import CoreModelId
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 
@@ -15,7 +17,10 @@ from together import Together
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -33,25 +38,55 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import TogetherImplConfig
 
-TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
-    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
-    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
-    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
-    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
-    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-3.2-3B-Instruct-Turbo",
+        aliases=["Llama3.2-3B-Instruct"],
+        llama_model=CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+        aliases=["Llama3.2-11B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
+        aliases=["Llama3.2-90B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_90b_vision_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Meta-Llama-Guard-3-8B",
+        aliases=["Llama-Guard-3-8B"],
+        llama_model=CoreModelId.llama_guard_3_8b.value,
+    ),
+    ModelAlias(
+        provider_model_id="meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+        aliases=["Llama-Guard-3-11B-Vision"],
+        llama_model=CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]
 
 
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -70,8 +105,9 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -145,8 +181,9 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -204,7 +241,7 @@ class TogetherInferenceAdapter(
             ]
         else:
             input_dict["prompt"] = chat_completion_request_to_prompt(
-                request, self.formatter
+                request, self.get_llama_model(request.model), self.formatter
             )
         else:
             assert (
@@ -213,7 +250,7 @@ class TogetherInferenceAdapter(
                 not media_present
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
 
         return {
-            "model": self.map_to_provider_model(request.model),
+            "model": request.model,
             **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
provider_model_id="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + aliases=["Llama3.1-70B-Instruct"], + llama_model=CoreModelId.llama3_1_70b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + aliases=["Llama3.1-405B-Instruct"], + llama_model=CoreModelId.llama3_1_405b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-3.2-3B-Instruct-Turbo", + aliases=["Llama3.2-3B-Instruct"], + llama_model=CoreModelId.llama3_2_3b_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", + aliases=["Llama3.2-11B-Vision-Instruct"], + llama_model=CoreModelId.llama3_2_11b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", + aliases=["Llama3.2-90B-Vision-Instruct"], + llama_model=CoreModelId.llama3_2_90b_vision_instruct.value, + ), + ModelAlias( + provider_model_id="meta-llama/Meta-Llama-Guard-3-8B", + aliases=["Llama-Guard-3-8B"], + llama_model=CoreModelId.llama_guard_3_8b.value, + ), + ModelAlias( + provider_model_id="meta-llama/Llama-Guard-3-11B-Vision-Turbo", + aliases=["Llama-Guard-3-11B-Vision"], + llama_model=CoreModelId.llama_guard_3_11b_vision.value, + ), +] class TogetherInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -70,8 +105,9 @@ class TogetherInferenceAdapter( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model_id, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -145,8 +181,9 @@ class TogetherInferenceAdapter( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model_id, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -204,7 +241,7 @@ class TogetherInferenceAdapter( ] else: input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( @@ -213,7 +250,7 @@ class TogetherInferenceAdapter( input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) return { - "model": self.map_to_provider_model(request.model), + "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format), diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index bdc5af0f9..b3401d8f5 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -4,32 +4,54 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
 
 class ModelRegistryHelper(ModelsProtocolPrivate):
-    def __init__(self, stack_to_provider_models_map: Dict[str, str]):
-        self.stack_to_provider_models_map = stack_to_provider_models_map
+    def __init__(self, model_aliases: List[ModelAlias]):
+        self.model_lookup = ModelLookup(model_aliases)
 
-    def map_to_provider_model(self, identifier: str) -> str:
-        model = resolve_model(identifier)
-        if not model:
-            raise ValueError(f"Unknown model: `{identifier}`")
+    def get_llama_model(self, provider_model_id: str) -> str:
+        return self.model_lookup.provider_id_to_llama_model_map[provider_model_id]
 
-        if identifier not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
-            )
+    async def register_model(self, model: Model) -> Model:
+        provider_model_id = self.model_lookup.get_provider_model_id(
+            model.provider_resource_id
+        )
+        if not provider_model_id:
+            raise ValueError(f"Unknown model: `{model.provider_resource_id}`")
 
-        return self.stack_to_provider_models_map[identifier]
+        model.provider_resource_id = provider_model_id
 
-    async def register_model(self, model: Model) -> None:
-        if model.provider_resource_id not in self.stack_to_provider_models_map:
-            raise ValueError(
-                f"Unsupported model {model.provider_resource_id}. Supported models: {self.stack_to_provider_models_map.keys()}"
-            )
+        return model
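A hedged usage sketch of the registration path defined above, showing why the routing table now keeps the object returned by the provider: register_model resolves whatever identifier the caller supplied through the alias map and rewrites provider_resource_id before handing the model back. DummyModel is a hypothetical stand-in for llama_stack's Model resource, and the single alias entry is illustrative:

from collections import namedtuple

ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])


class DummyModel:
    """Hypothetical stand-in for the Model resource handed to register_model."""

    def __init__(self, identifier, provider_resource_id):
        self.identifier = identifier
        self.provider_resource_id = provider_resource_id


alias = ModelAlias(
    provider_model_id="fireworks/llama-v3p1-8b-instruct",
    aliases=["Llama3.1-8B-Instruct"],
    llama_model="Llama3.1-8B-Instruct",
)
alias_to_provider_id = {a: alias.provider_model_id for a in alias.aliases}
alias_to_provider_id[alias.provider_model_id] = alias.provider_model_id

model = DummyModel("Llama3.1-8B-Instruct", provider_resource_id="Llama3.1-8B-Instruct")
# register_model resolves the alias and rewrites provider_resource_id on the object,
# which is why the routing table now persists the object returned by the provider
# instead of the one the caller passed in.
model.provider_resource_id = alias_to_provider_id[model.provider_resource_id]
print(model.identifier, "->", model.provider_resource_id)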
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 45e43c898..2df04664f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content):
 
 
 def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> str:
-    messages = chat_completion_request_to_messages(request)
+    messages = chat_completion_request_to_messages(request, llama_model)
     model_input = formatter.encode_dialog_prompt(messages)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
 def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    messages = chat_completion_request_to_messages(request)
+    messages = chat_completion_request_to_messages(request, llama_model)
     model_input = formatter.encode_dialog_prompt(messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
@@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info(
 
 def chat_completion_request_to_messages(
     request: ChatCompletionRequest,
+    llama_model: str,
 ) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     E.g. for llama_3_1, add a system message with the appropriate tools or add a
     user message for custom tools, etc.
     """
-    model = resolve_model(request.model)
+    model = resolve_model(llama_model)
     if model is None:
-        cprint(f"Could not resolve model {request.model}", color="red")
+        cprint(f"Could not resolve model {llama_model}", color="red")
         return request.messages
 
     if model.descriptor() not in supported_inference_models():
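Finally, the reason chat_completion_request_to_messages now takes llama_model explicitly: after this patch request.model carries the provider's own id (for example "fireworks/llama-v3p1-8b-instruct"), which resolve_model cannot map back to a Llama SKU, so the adapter passes the canonical descriptor it obtained from get_llama_model. A minimal sketch of the new call shape, with illustrative strings only:

def build_prompt_sketch(request_model: str, llama_model: str) -> str:
    # Before: the prompt builder resolved request.model itself, which only worked
    # while request.model was a canonical descriptor such as "Llama3.1-8B-Instruct".
    # After: the adapter supplies the descriptor, so request.model is free to be a
    # provider id such as "fireworks/llama-v3p1-8b-instruct".
    return f"format dialog for {llama_model}; send it to provider model {request_model}"


print(
    build_prompt_sketch(
        request_model="fireworks/llama-v3p1-8b-instruct",
        llama_model="Llama3.1-8B-Instruct",
    )
)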