From fdff24e77a43636c78240a914f830c45c331e019 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 12 Nov 2024 20:02:00 -0800 Subject: [PATCH] Inference to use provider resource id to register and validate (#428) This PR changes the way model id gets translated to the final model name that gets passed through the provider. Major changes include: 1) Providers are responsible for registering an object and as part of the registration returning the object with the correct provider specific name of the model provider_resource_id 2) To help with the common look ups different names a new ModelLookup class is created. Tested all inference providers including together, fireworks, vllm, ollama, meta reference and bedrock --- docs/resources/llama-stack-spec.html | 86 ++++++++-------- docs/resources/llama-stack-spec.yaml | 42 ++++---- docs/source/getting_started/index.md | 2 +- llama_stack/apis/inference/inference.py | 6 +- llama_stack/distribution/routers/routers.py | 18 ++-- .../distribution/routers/routing_tables.py | 27 +++-- .../inline/eval/meta_reference/eval.py | 2 +- .../inference/meta_reference/generation.py | 14 ++- .../inference/meta_reference/inference.py | 33 ++++--- .../providers/inline/inference/vllm/vllm.py | 10 +- .../scoring_fn/llm_as_judge_scoring_fn.py | 2 +- .../remote/inference/bedrock/bedrock.py | 40 +++++--- .../remote/inference/databricks/databricks.py | 30 ++++-- .../remote/inference/fireworks/fireworks.py | 76 +++++++++----- .../remote/inference/ollama/ollama.py | 99 ++++++++++--------- .../remote/inference/together/together.py | 71 +++++++++---- .../providers/remote/inference/vllm/vllm.py | 71 +++++++------ .../providers/tests/inference/fixtures.py | 27 ++++- .../tests/inference/test_text_inference.py | 16 +-- .../utils/inference/model_registry.py | 65 ++++++++---- .../utils/inference/prompt_adapter.py | 13 +-- 21 files changed, 460 insertions(+), 290 deletions(-) diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 7ef4ece21..f87cb5590 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. 
The specification is still in draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543" }, "servers": [ { @@ -2856,7 +2856,7 @@ "ChatCompletionRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "messages": { @@ -2993,7 +2993,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "messages" ] }, @@ -3120,7 +3120,7 @@ "CompletionRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "content": { @@ -3249,7 +3249,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "content" ] }, @@ -4552,7 +4552,7 @@ "EmbeddingsRequest": { "type": "object", "properties": { - "model": { + "model_id": { "type": "string" }, "contents": { @@ -4584,7 +4584,7 @@ }, "additionalProperties": false, "required": [ - "model", + "model_id", "contents" ] }, @@ -7837,34 +7837,10 @@ ], "tags": [ { - "name": "MemoryBanks" + "name": "Safety" }, { - "name": "BatchInference" - }, - { - "name": "Agents" - }, - { - "name": "Inference" - }, - { - "name": "DatasetIO" - }, - { - "name": "Eval" - }, - { - "name": "Models" - }, - { - "name": "PostTraining" - }, - { - "name": "ScoringFunctions" - }, - { - "name": "Datasets" + "name": "EvalTasks" }, { "name": "Shields" @@ -7872,15 +7848,6 @@ { "name": "Telemetry" }, - { - "name": "Inspect" - }, - { - "name": "Safety" - }, - { - "name": "SyntheticDataGeneration" - }, { "name": "Memory" }, @@ -7888,7 +7855,40 @@ "name": "Scoring" }, { - "name": "EvalTasks" + "name": "ScoringFunctions" + }, + { + "name": "SyntheticDataGeneration" + }, + { + "name": "Models" + }, + { + "name": "Agents" + }, + { + "name": "MemoryBanks" + }, + { + "name": "DatasetIO" + }, + { + "name": "Inference" + }, + { + "name": "Datasets" + }, + { + "name": "PostTraining" + }, + { + "name": "BatchInference" + }, + { + "name": "Eval" + }, + { + "name": "Inspect" }, { "name": "BuiltinTool", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index b86c0df61..87268ff47 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -396,7 +396,7 @@ components: - $ref: '#/components/schemas/ToolResponseMessage' - $ref: '#/components/schemas/CompletionMessage' type: array - model: + model_id: type: string response_format: oneOf: @@ -453,7 +453,7 @@ components: $ref: '#/components/schemas/ToolDefinition' type: array required: - - model + - model_id - messages type: object ChatCompletionResponse: @@ -577,7 +577,7 @@ components: default: 0 type: integer type: object - model: + model_id: type: string response_format: oneOf: @@ -626,7 +626,7 @@ components: stream: type: boolean required: - - model + - model_id - content type: object CompletionResponse: @@ -903,10 +903,10 @@ components: - $ref: '#/components/schemas/ImageMedia' type: array type: array - model: + model_id: type: string required: - - model + - model_id - contents type: object EmbeddingsResponse: @@ -3384,7 +3384,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782" + \ draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -4748,24 +4748,24 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: MemoryBanks -- name: BatchInference -- name: Agents -- name: Inference -- name: DatasetIO -- name: Eval -- name: Models -- name: PostTraining -- name: ScoringFunctions -- name: Datasets +- name: Safety +- name: EvalTasks - name: Shields - name: Telemetry -- name: Inspect -- name: Safety -- name: SyntheticDataGeneration - name: Memory - name: Scoring -- name: EvalTasks +- name: ScoringFunctions +- name: SyntheticDataGeneration +- name: Models +- name: Agents +- name: MemoryBanks +- name: DatasetIO +- name: Inference +- name: Datasets +- name: PostTraining +- name: BatchInference +- name: Eval +- name: Inspect - description: name: BuiltinTool - description: EmbeddingsResponse: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 220dfdb56..5a62b6d64 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -95,7 +95,7 @@ class InferenceRouter(Inference): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -106,7 +106,7 @@ class InferenceRouter(Inference): logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: params = dict( - model=model, + model_id=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -116,7 +116,7 @@ class InferenceRouter(Inference): stream=stream, logprobs=logprobs, ) - provider = self.routing_table.get_provider_impl(model) + provider = self.routing_table.get_provider_impl(model_id) if stream: return (chunk async for chunk in await provider.chat_completion(**params)) else: @@ -124,16 +124,16 @@ class InferenceRouter(Inference): async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - provider = self.routing_table.get_provider_impl(model) + provider = self.routing_table.get_provider_impl(model_id) params = dict( - model=model, + model_id=model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -147,11 +147,11 @@ class InferenceRouter(Inference): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: - return await self.routing_table.get_provider_impl(model).embeddings( - model=model, + return await self.routing_table.get_provider_impl(model_id).embeddings( + model_id=model_id, contents=contents, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d6fb5d662..249d3a144 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api: return p.__provider_spec__.api -async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: +# TODO: this should 
return the registered object for all APIs +async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: + api = get_impl_api(p) if obj.provider_id == "remote": @@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: obj.provider_id = "" if api == Api.inference: - await p.register_model(obj) + return await p.register_model(obj) elif api == Api.safety: await p.register_shield(obj) elif api == Api.memory: @@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable): assert len(objects) == 1 return objects[0] - async def register_object(self, obj: RoutableObjectWithProvider): + async def register_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: # Get existing objects from registry existing_objects = await self.dist_registry.get(obj.type, obj.identifier) @@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable): print( f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`" ) - return + return existing_obj # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: @@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable): p = self.impls_by_provider_id[obj.provider_id] - await register_object_with_provider(obj, p) - await self.dist_registry.register(obj) + registered_obj = await register_object_with_provider(obj, p) + # TODO: This needs to be fixed for all APIs once they return the registered object + if obj.type == ResourceType.model.value: + await self.dist_registry.register(registered_obj) + return registered_obj + + else: + await self.dist_registry.register(obj) + return obj async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() @@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=metadata, ) - await self.register_object(model) - return model + registered_model = await self.register_object(model) + return registered_model class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index ba2fc7c95..58241eb42 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -150,7 +150,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): messages.append(candidate.system_message) messages += input_messages response = await self.inference_api.chat_completion( - model=candidate.model, + model_id=candidate.model, messages=messages, sampling_params=candidate.sampling_params, ) diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 2f296c7c2..38c982473 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -86,6 +86,7 @@ class Llama: and loads the pre-trained model and tokenizer. 
""" model = resolve_model(config.model) + llama_model = model.core_model_id.value if not torch.distributed.is_initialized(): torch.distributed.init_process_group("nccl") @@ -186,13 +187,20 @@ class Llama: model.load_state_dict(state_dict, strict=False) print(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args) + return Llama(model, tokenizer, model_args, llama_model) - def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + def __init__( + self, + model: Transformer, + tokenizer: Tokenizer, + args: ModelArgs, + llama_model: str, + ): self.args = args self.model = model self.tokenizer = tokenizer self.formatter = ChatFormat(tokenizer) + self.llama_model = llama_model @torch.inference_mode() def generate( @@ -369,7 +377,7 @@ class Llama: self, request: ChatCompletionRequest, ) -> Generator: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, self.llama_model) sampling_params = request.sampling_params max_gen_len = sampling_params.max_tokens diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 2fdc8f2d5..4f5c0c8c2 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -11,9 +11,11 @@ from typing import AsyncGenerator, List from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import build_model_alias +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_media_to_url, request_has_media, @@ -28,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator SEMAPHORE = asyncio.Semaphore(1) -class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): +class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): def __init__(self, config: MetaReferenceInferenceConfig) -> None: self.config = config model = resolve_model(config.model) + ModelRegistryHelper.__init__( + self, + [ + build_model_alias( + model.descriptor(), + model.core_model_id.value, + ) + ], + ) if model is None: raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") self.model = model @@ -45,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): else: self.generator = Llama.build(self.config) - async def register_model(self, model: Model) -> None: - if model.identifier != self.model.descriptor(): - raise ValueError( - f"Model mismatch: {model.identifier} != {self.model.descriptor()}" - ) - async def shutdown(self) -> None: if self.config.create_distributed_process_group: self.generator.stop() @@ -68,7 +73,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -79,7 +84,7 @@ class MetaReferenceInferenceImpl(Inference, 
ModelsProtocolPrivate): assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" request = CompletionRequest( - model=model, + model=model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -186,7 +191,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -201,7 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): # wrapper request to make it easier to pass around (internal only, not exposed to API) request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -386,7 +391,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 3b1a0dd50..8869cc07f 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -110,7 +110,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -120,7 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): log.info("vLLM completion") messages = [UserMessage(content=content)] return self.chat_completion( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, stream=stream, @@ -129,7 +129,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), tools: Optional[List[ToolDefinition]] = None, @@ -144,7 +144,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): assert self.engine is not None request = ChatCompletionRequest( - model=model, + model=model_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -215,7 +215,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): yield chunk async def embeddings( - self, model: str, contents: list[InterleavedTextMedia] + self, model_id: str, contents: list[InterleavedTextMedia] ) -> EmbeddingsResponse: log.info("vLLM embeddings") # TODO diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index a950f35f9..4b43de93f 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -62,7 +62,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn): ) judge_response = await self.inference_api.chat_completion( - model=fn_def.params.judge_model, + model_id=fn_def.params.judge_model, messages=[ { "role": "user", diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index d9f82c611..f575d9dc3 100644 --- 
a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -7,11 +7,15 @@ from typing import * # noqa: F403 from botocore.client import BaseClient +from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) from llama_stack.apis.inference import * # noqa: F403 @@ -19,19 +23,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client -BEDROCK_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", - "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", - "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", -} +model_aliases = [ + build_model_alias( + "meta.llama3-1-8b-instruct-v1:0", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "meta.llama3-1-70b-instruct-v1:0", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "meta.llama3-1-405b-instruct-v1:0", + CoreModelId.llama3_1_405b_instruct.value, + ), +] # NOTE: this is not quite tested after the recent refactors class BedrockInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: BedrockConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self._config = config self._client = create_bedrock_client(config) @@ -49,7 +60,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -286,7 +297,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -298,8 +309,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) -> Union[ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] ]: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -404,7 +416,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): pass def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: - bedrock_model = self.map_to_provider_model(request.model) + bedrock_model = request.model inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( request.sampling_params ) @@ -433,7 +445,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index f12ecb7f5..0ebb625bc 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ 
b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -6,6 +6,8 @@ from typing import AsyncGenerator +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message @@ -15,7 +17,10 @@ from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -28,16 +33,23 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import DatabricksImplConfig -DATABRICKS_SUPPORTED_MODELS = { - "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct", - "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct", -} +model_aliases = [ + build_model_alias( + "databricks-meta-llama-3-1-70b-instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "databricks-meta-llama-3-1-405b-instruct", + CoreModelId.llama3_1_405b_instruct.value, + ), +] class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: DatabricksImplConfig) -> None: ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS + self, + model_aliases=model_aliases, ) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -113,8 +125,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): def _get_params(self, request: ChatCompletionRequest) -> dict: return { - "model": self.map_to_provider_model(request.model), - "prompt": chat_completion_request_to_prompt(request, self.formatter), + "model": request.model, + "prompt": chat_completion_request_to_prompt( + request, self.get_llama_model(request.model), self.formatter + ), "stream": request.stream, **get_sampling_options(request.sampling_params), } diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 57e851c5b..42075eff7 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -7,14 +7,17 @@ from typing import AsyncGenerator from fireworks.client import Fireworks +from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer - from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -31,25 +34,52 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig -FIREWORKS_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", - "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", - "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct", - "Llama3.2-1B-Instruct": 
"fireworks/llama-v3p2-1b-instruct", - "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct", - "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct", - "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct", - "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b", -} + +model_aliases = [ + build_model_alias( + "fireworks/llama-v3p1-8b-instruct", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "fireworks/llama-v3p1-70b-instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "fireworks/llama-v3p1-405b-instruct", + CoreModelId.llama3_1_405b_instruct.value, + ), + build_model_alias( + "fireworks/llama-v3p2-1b-instruct", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_alias( + "fireworks/llama-v3p2-3b-instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_alias( + "fireworks/llama-v3p2-11b-vision-instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_alias( + "fireworks/llama-v3p2-90b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + build_model_alias( + "fireworks/llama-guard-3-8b", + CoreModelId.llama_guard_3_8b.value, + ), + build_model_alias( + "fireworks/llama-guard-3-11b-vision", + CoreModelId.llama_guard_3_11b_vision.value, + ), +] class FireworksInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): def __init__(self, config: FireworksImplConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -74,15 +104,16 @@ class FireworksInferenceAdapter( async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -138,7 +169,7 @@ class FireworksInferenceAdapter( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), tools: Optional[List[ToolDefinition]] = None, @@ -148,8 +179,9 @@ class FireworksInferenceAdapter( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -207,7 +239,7 @@ class FireworksInferenceAdapter( ] else: input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( @@ -221,7 +253,7 @@ class FireworksInferenceAdapter( input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :] return { - "model": self.map_to_provider_model(request.model), + "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format), @@ -229,7 +261,7 @@ class FireworksInferenceAdapter( async def embeddings( self, - model: str, + model_id: str, contents: 
List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 938d05c08..99f74572e 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -7,15 +7,20 @@ from typing import AsyncGenerator import httpx +from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer - from ollama import AsyncClient +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) + from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, @@ -33,19 +38,45 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( request_has_media, ) -OLLAMA_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", - "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", - "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16", - "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16", - "Llama-Guard-3-8B": "llama-guard3:8b", - "Llama-Guard-3-1B": "llama-guard3:1b", - "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16", -} + +model_aliases = [ + build_model_alias( + "llama3.1:8b-instruct-fp16", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3.1:70b-instruct-fp16", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "llama3.2:1b-instruct-fp16", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_model_alias( + "llama3.2:3b-instruct-fp16", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_alias( + "llama-guard3:8b", + CoreModelId.llama_guard_3_8b.value, + ), + build_model_alias( + "llama-guard3:1b", + CoreModelId.llama_guard_3_1b.value, + ), + build_model_alias( + "x/llama3.2-vision:11b-instruct-fp16", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), +] -class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): +class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate): def __init__(self, url: str) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=model_aliases, + ) self.url = url self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -65,44 +96,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def shutdown(self) -> None: pass - async def register_model(self, model: Model) -> None: - if model.identifier not in OLLAMA_SUPPORTED_MODELS: - raise ValueError(f"Model {model.identifier} is not supported by Ollama") - - async def list_models(self) -> List[Model]: - ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()} - - ret = [] - res = await self.client.ps() - for r in res["models"]: - if r["model"] not in ollama_to_llama: - print(f"Ollama is running a model unknown to Llama Stack: {r['model']}") - continue - - llama_model = ollama_to_llama[r["model"]] - print(f"Found model {llama_model} in Ollama") - ret.append( - Model( - identifier=llama_model, - metadata={ - "ollama_model": r["model"], - }, - ) - ) - - return ret - async def completion( self, - model: str, + model_id: str, content: 
InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, stream=stream, @@ -148,7 +153,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -158,8 +163,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + print(f"model={model}") request = ChatCompletionRequest( - model=model, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -197,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): else: input_dict["raw"] = True input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( @@ -207,7 +214,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): input_dict["raw"] = True return { - "model": OLLAMA_SUPPORTED_MODELS[request.model], + "model": request.model, **input_dict, "options": sampling_options, "stream": request.stream, @@ -271,7 +278,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 28a566415..aae34bb87 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -6,6 +6,8 @@ from typing import AsyncGenerator +from llama_models.datatypes import CoreModelId + from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message @@ -15,7 +17,10 @@ from together import Together from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -33,25 +38,47 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig -TOGETHER_SUPPORTED_MODELS = { - "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo", - "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", - "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", - 
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B", - "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo", -} +model_aliases = [ + build_model_alias( + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", + CoreModelId.llama3_1_405b_instruct.value, + ), + build_model_alias( + "meta-llama/Llama-3.2-3B-Instruct-Turbo", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_alias( + "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_alias( + "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + build_model_alias( + "meta-llama/Meta-Llama-Guard-3-8B", + CoreModelId.llama_guard_3_8b.value, + ), + build_model_alias( + "meta-llama/Llama-Guard-3-11B-Vision-Turbo", + CoreModelId.llama_guard_3_11b_vision.value, + ), +] class TogetherInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS - ) + ModelRegistryHelper.__init__(self, model_aliases) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) @@ -63,15 +90,16 @@ class TogetherInferenceAdapter( async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = CompletionRequest( - model=model, + model=model.provider_resource_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -135,7 +163,7 @@ class TogetherInferenceAdapter( async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), tools: Optional[List[ToolDefinition]] = None, @@ -145,8 +173,9 @@ class TogetherInferenceAdapter( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -204,7 +233,7 @@ class TogetherInferenceAdapter( ] else: input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( @@ -213,7 +242,7 @@ class TogetherInferenceAdapter( input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) return { - "model": self.map_to_provider_model(request.model), + "model": request.model, **input_dict, "stream": request.stream, **self._build_options(request.sampling_params, request.response_format), @@ -221,7 +250,7 @@ class TogetherInferenceAdapter( async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index bd7f5073c..e5eb6e1ea 100644 
--- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -8,13 +8,17 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer -from llama_models.sku_list import all_registered_models, resolve_model +from llama_models.sku_list import all_registered_models from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +from llama_stack.providers.datatypes import ModelsProtocolPrivate +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -30,44 +34,36 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): +def build_model_aliases(): + return [ + build_model_alias( + model.huggingface_repo, + model.descriptor(), + ) + for model in all_registered_models() + if model.huggingface_repo + ] + + +class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate): def __init__(self, config: VLLMInferenceAdapterConfig) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=build_model_aliases(), + ) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) self.client = None - self.huggingface_repo_to_llama_model_id = { - model.huggingface_repo: model.descriptor() - for model in all_registered_models() - if model.huggingface_repo - } async def initialize(self) -> None: self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) - async def register_model(self, model: Model) -> None: - for running_model in self.client.models.list(): - repo = running_model.id - if repo not in self.huggingface_repo_to_llama_model_id: - print(f"Unknown model served by vllm: {repo}") - continue - - identifier = self.huggingface_repo_to_llama_model_id[repo] - if identifier == model.provider_resource_id: - print( - f"Verified that model {model.provider_resource_id} is being served by vLLM" - ) - return - - raise ValueError( - f"Model {model.provider_resource_id} is not being served by vLLM" - ) - async def shutdown(self) -> None: pass async def completion( self, - model: str, + model_id: str, content: InterleavedTextMedia, sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -78,7 +74,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, @@ -88,8 +84,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) request = ChatCompletionRequest( - model=model, + model=model.provider_resource_id, messages=messages, sampling_params=sampling_params, tools=tools or [], @@ -141,10 +138,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): if "max_tokens" not in options: options["max_tokens"] = self.config.max_tokens - model = 
resolve_model(request.model) - if model is None: - raise ValueError(f"Unknown model: {request.model}") - input_dict = {} media_present = request_has_media(request) if isinstance(request, ChatCompletionRequest): @@ -156,16 +149,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): ] else: input_dict["prompt"] = chat_completion_request_to_prompt( - request, self.formatter + request, self.get_llama_model(request.model), self.formatter ) else: assert ( not media_present ), "Together does not support media for Completion requests" - input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) + input_dict["prompt"] = completion_request_to_prompt( + request, + self.get_llama_model(request.model), + self.formatter, + ) return { - "model": model.huggingface_repo, + "model": request.model, **input_dict, "stream": request.stream, **options, @@ -173,7 +170,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): async def embeddings( self, - model: str, + model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d35ebab28..f6f2a30e8 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -49,7 +49,7 @@ def inference_meta_reference(inference_model) -> ProviderFixture: providers=[ Provider( provider_id=f"meta-reference-{i}", - provider_type="meta-reference", + provider_type="inline::meta-reference", config=MetaReferenceInferenceConfig( model=m, max_seq_len=4096, @@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture: ) +def get_model_short_name(model_name: str) -> str: + """Convert model name to a short test identifier. 
+ + Args: + model_name: Full model name like "Llama3.1-8B-Instruct" + + Returns: + Short name like "llama_8b" suitable for test markers + """ + model_name = model_name.lower() + if "vision" in model_name: + return "llama_vision" + elif "3b" in model_name: + return "llama_3b" + elif "8b" in model_name: + return "llama_8b" + else: + return model_name.replace(".", "_").replace("-", "_") + + +@pytest.fixture(scope="session") +def model_id(inference_model) -> str: + return get_model_short_name(inference_model) + + INFERENCE_FIXTURES = [ "meta_reference", "ollama", diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index e7bfbc135..70047a61f 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -96,7 +96,7 @@ class TestInference: response = await inference_impl.completion( content="Micheael Jordan is born in ", stream=False, - model=inference_model, + model_id=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -110,7 +110,7 @@ class TestInference: async for r in await inference_impl.completion( content="Roses are red,", stream=True, - model=inference_model, + model_id=inference_model, sampling_params=SamplingParams( max_tokens=50, ), @@ -171,7 +171,7 @@ class TestInference: ): inference_impl, _ = inference_stack response = await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=sample_messages, stream=False, **common_params, @@ -204,7 +204,7 @@ class TestInference: num_seasons_in_nba: int response = await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -227,7 +227,7 @@ class TestInference: assert answer.num_seasons_in_nba == 15 response = await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(content="Please give me information about Michael Jordan."), @@ -250,7 +250,7 @@ class TestInference: response = [ r async for r in await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=sample_messages, stream=True, **common_params, @@ -286,7 +286,7 @@ class TestInference: ] response = await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=messages, tools=[sample_tool_definition], stream=False, @@ -327,7 +327,7 @@ class TestInference: response = [ r async for r in await inference_impl.chat_completion( - model=inference_model, + model_id=inference_model, messages=messages, tools=[sample_tool_definition], stream=True, diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 141e4af31..7120e9e97 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -4,32 +4,61 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Dict +from collections import namedtuple +from typing import List, Optional -from llama_models.sku_list import resolve_model +from llama_models.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate +ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"]) + + +def get_huggingface_repo(model_descriptor: str) -> Optional[str]: + for model in all_registered_models(): + if model.descriptor() == model_descriptor: + return model.huggingface_repo + return None + + +def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias: + return ModelAlias( + provider_model_id=provider_model_id, + aliases=[ + model_descriptor, + get_huggingface_repo(model_descriptor), + ], + llama_model=model_descriptor, + ) + class ModelRegistryHelper(ModelsProtocolPrivate): + def __init__(self, model_aliases: List[ModelAlias]): + self.alias_to_provider_id_map = {} + self.provider_id_to_llama_model_map = {} + for alias_obj in model_aliases: + for alias in alias_obj.aliases: + self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id + # also add a mapping from provider model id to itself for easy lookup + self.alias_to_provider_id_map[alias_obj.provider_model_id] = ( + alias_obj.provider_model_id + ) + self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = ( + alias_obj.llama_model + ) - def __init__(self, stack_to_provider_models_map: Dict[str, str]): - self.stack_to_provider_models_map = stack_to_provider_models_map - - def map_to_provider_model(self, identifier: str) -> str: - model = resolve_model(identifier) - if not model: + def get_provider_model_id(self, identifier: str) -> str: + if identifier in self.alias_to_provider_id_map: + return self.alias_to_provider_id_map[identifier] + else: raise ValueError(f"Unknown model: `{identifier}`") - if identifier not in self.stack_to_provider_models_map: - raise ValueError( - f"Model {identifier} not found in map {self.stack_to_provider_models_map}" - ) + def get_llama_model(self, provider_model_id: str) -> str: + return self.provider_id_to_llama_model_map[provider_model_id] - return self.stack_to_provider_models_map[identifier] + async def register_model(self, model: Model) -> Model: + model.provider_resource_id = self.get_provider_model_id( + model.provider_resource_id + ) - async def register_model(self, model: Model) -> None: - if model.identifier not in self.stack_to_provider_models_map: - raise ValueError( - f"Unsupported model {model.identifier}. 
Supported models: {self.stack_to_provider_models_map.keys()}" - ) + return model diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 45e43c898..2df04664f 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content): def chat_completion_request_to_prompt( - request: ChatCompletionRequest, formatter: ChatFormat + request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat ) -> str: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, llama_model) model_input = formatter.encode_dialog_prompt(messages) return formatter.tokenizer.decode(model_input.tokens) def chat_completion_request_to_model_input_info( - request: ChatCompletionRequest, formatter: ChatFormat + request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat ) -> Tuple[str, int]: - messages = chat_completion_request_to_messages(request) + messages = chat_completion_request_to_messages(request, llama_model) model_input = formatter.encode_dialog_prompt(messages) return ( formatter.tokenizer.decode(model_input.tokens), @@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info( def chat_completion_request_to_messages( request: ChatCompletionRequest, + llama_model: str, ) -> List[Message]: """Reads chat completion request and augments the messages to handle tools. For eg. for llama_3_1, add system message with the appropriate tools or add user messsage for custom tools, etc. """ - model = resolve_model(request.model) + model = resolve_model(llama_model) if model is None: - cprint(f"Could not resolve model {request.model}", color="red") + cprint(f"Could not resolve model {llama_model}", color="red") return request.messages if model.descriptor() not in supported_inference_models():
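
Note (not part of the patch): below is a minimal, standalone sketch of the alias-resolution flow that build_model_alias / ModelAlias / ModelRegistryHelper introduce in llama_stack/providers/utils/inference/model_registry.py, written so it runs without llama_models or llama_stack installed. The FakeModel dataclass and the synchronous register_model are stand-ins for illustration only (the real Model type lives in llama_stack.providers.datatypes and the real register_model is async), and the HuggingFace-repo alias that build_model_alias normally adds via llama_models.sku_list is omitted. The Fireworks alias in the usage block mirrors the mapping in the diff above.

    from collections import namedtuple
    from dataclasses import dataclass
    from typing import Dict, List

    ModelAlias = namedtuple(
        "ModelAlias", ["provider_model_id", "aliases", "llama_model"]
    )


    def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
        # Simplified: the real helper also appends the HuggingFace repo name
        # (looked up via llama_models.sku_list) to the alias list.
        return ModelAlias(
            provider_model_id=provider_model_id,
            aliases=[model_descriptor],
            llama_model=model_descriptor,
        )


    @dataclass
    class FakeModel:
        # Stand-in for llama_stack.providers.datatypes.Model (illustrative only).
        identifier: str
        provider_resource_id: str


    class ModelRegistryHelper:
        def __init__(self, model_aliases: List[ModelAlias]) -> None:
            self.alias_to_provider_id_map: Dict[str, str] = {}
            self.provider_id_to_llama_model_map: Dict[str, str] = {}
            for alias_obj in model_aliases:
                for alias in alias_obj.aliases:
                    self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
                # map the provider model id to itself so either form resolves
                self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
                    alias_obj.provider_model_id
                )
                self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
                    alias_obj.llama_model
                )

        def get_provider_model_id(self, identifier: str) -> str:
            if identifier in self.alias_to_provider_id_map:
                return self.alias_to_provider_id_map[identifier]
            raise ValueError(f"Unknown model: `{identifier}`")

        def get_llama_model(self, provider_model_id: str) -> str:
            return self.provider_id_to_llama_model_map[provider_model_id]

        def register_model(self, model: FakeModel) -> FakeModel:
            # Registration rewrites provider_resource_id to the provider-specific
            # name; the real method is async, this sketch keeps it synchronous.
            model.provider_resource_id = self.get_provider_model_id(
                model.provider_resource_id
            )
            return model


    if __name__ == "__main__":
        helper = ModelRegistryHelper(
            [build_model_alias("fireworks/llama-v3p1-8b-instruct", "Llama3.1-8B-Instruct")]
        )
        model = FakeModel(
            identifier="Llama3.1-8B-Instruct",
            provider_resource_id="Llama3.1-8B-Instruct",
        )
        model = helper.register_model(model)
        assert model.provider_resource_id == "fireworks/llama-v3p1-8b-instruct"
        assert helper.get_llama_model(model.provider_resource_id) == "Llama3.1-8B-Instruct"

The same provider_model_id table backs get_llama_model(), which is why the prompt helpers (chat_completion_request_to_prompt, chat_completion_request_to_model_input_info, chat_completion_request_to_messages) now take the llama model descriptor as an explicit argument instead of re-resolving it from the provider-specific name stored on the request.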