diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 7ef4ece21..f87cb5590 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
},
"servers": [
{
@@ -2856,7 +2856,7 @@
"ChatCompletionRequest": {
"type": "object",
"properties": {
- "model": {
+ "model_id": {
"type": "string"
},
"messages": {
@@ -2993,7 +2993,7 @@
},
"additionalProperties": false,
"required": [
- "model",
+ "model_id",
"messages"
]
},
@@ -3120,7 +3120,7 @@
"CompletionRequest": {
"type": "object",
"properties": {
- "model": {
+ "model_id": {
"type": "string"
},
"content": {
@@ -3249,7 +3249,7 @@
},
"additionalProperties": false,
"required": [
- "model",
+ "model_id",
"content"
]
},
@@ -4552,7 +4552,7 @@
"EmbeddingsRequest": {
"type": "object",
"properties": {
- "model": {
+ "model_id": {
"type": "string"
},
"contents": {
@@ -4584,7 +4584,7 @@
},
"additionalProperties": false,
"required": [
- "model",
+ "model_id",
"contents"
]
},
@@ -7837,34 +7837,10 @@
],
"tags": [
{
- "name": "MemoryBanks"
+ "name": "Safety"
},
{
- "name": "BatchInference"
- },
- {
- "name": "Agents"
- },
- {
- "name": "Inference"
- },
- {
- "name": "DatasetIO"
- },
- {
- "name": "Eval"
- },
- {
- "name": "Models"
- },
- {
- "name": "PostTraining"
- },
- {
- "name": "ScoringFunctions"
- },
- {
- "name": "Datasets"
+ "name": "EvalTasks"
},
{
"name": "Shields"
@@ -7872,15 +7848,6 @@
{
"name": "Telemetry"
},
- {
- "name": "Inspect"
- },
- {
- "name": "Safety"
- },
- {
- "name": "SyntheticDataGeneration"
- },
{
"name": "Memory"
},
@@ -7888,7 +7855,40 @@
"name": "Scoring"
},
{
- "name": "EvalTasks"
+ "name": "ScoringFunctions"
+ },
+ {
+ "name": "SyntheticDataGeneration"
+ },
+ {
+ "name": "Models"
+ },
+ {
+ "name": "Agents"
+ },
+ {
+ "name": "MemoryBanks"
+ },
+ {
+ "name": "DatasetIO"
+ },
+ {
+ "name": "Inference"
+ },
+ {
+ "name": "Datasets"
+ },
+ {
+ "name": "PostTraining"
+ },
+ {
+ "name": "BatchInference"
+ },
+ {
+ "name": "Eval"
+ },
+ {
+ "name": "Inspect"
},
{
"name": "BuiltinTool",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index b86c0df61..87268ff47 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -396,7 +396,7 @@ components:
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
type: array
- model:
+ model_id:
type: string
response_format:
oneOf:
@@ -453,7 +453,7 @@ components:
$ref: '#/components/schemas/ToolDefinition'
type: array
required:
- - model
+ - model_id
- messages
type: object
ChatCompletionResponse:
@@ -577,7 +577,7 @@ components:
default: 0
type: integer
type: object
- model:
+ model_id:
type: string
response_format:
oneOf:
@@ -626,7 +626,7 @@ components:
stream:
type: boolean
required:
- - model
+ - model_id
- content
type: object
CompletionResponse:
@@ -903,10 +903,10 @@ components:
- $ref: '#/components/schemas/ImageMedia'
type: array
type: array
- model:
+ model_id:
type: string
required:
- - model
+ - model_id
- contents
type: object
EmbeddingsResponse:
@@ -3384,7 +3384,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
+ \ draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4748,24 +4748,24 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: MemoryBanks
-- name: BatchInference
-- name: Agents
-- name: Inference
-- name: DatasetIO
-- name: Eval
-- name: Models
-- name: PostTraining
-- name: ScoringFunctions
-- name: Datasets
+- name: Safety
+- name: EvalTasks
- name: Shields
- name: Telemetry
-- name: Inspect
-- name: Safety
-- name: SyntheticDataGeneration
- name: Memory
- name: Scoring
-- name: EvalTasks
+- name: ScoringFunctions
+- name: SyntheticDataGeneration
+- name: Models
+- name: Agents
+- name: MemoryBanks
+- name: DatasetIO
+- name: Inference
+- name: Datasets
+- name: PostTraining
+- name: BatchInference
+- name: Eval
+- name: Inspect
- description:
name: BuiltinTool
- description: EmbeddingsResponse: ...
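Illustration (not part of the patch): with the renamed field, request bodies carry "model_id" rather than "model". A minimal Python sketch of a chat completion payload, assuming the /inference/chat_completion endpoint path:

    payload = {
        "model_id": "Llama3.1-8B-Instruct",   # identifier the model was registered under
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    }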
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 220dfdb56..5a62b6d64 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -95,7 +95,7 @@ class InferenceRouter(Inference):
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -106,7 +106,7 @@ class InferenceRouter(Inference):
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
params = dict(
- model=model,
+ model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -116,7 +116,7 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
- provider = self.routing_table.get_provider_impl(model)
+ provider = self.routing_table.get_provider_impl(model_id)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
else:
@@ -124,16 +124,16 @@ class InferenceRouter(Inference):
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
- provider = self.routing_table.get_provider_impl(model)
+ provider = self.routing_table.get_provider_impl(model_id)
params = dict(
- model=model,
+ model_id=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@@ -147,11 +147,11 @@ class InferenceRouter(Inference):
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
- return await self.routing_table.get_provider_impl(model).embeddings(
- model=model,
+ return await self.routing_table.get_provider_impl(model_id).embeddings(
+ model_id=model_id,
contents=contents,
)
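Illustration (not part of the patch): callers of the inference API now pass model_id, the identifier the model was registered under, instead of model. A minimal usage sketch with a hypothetical registered model name, mirroring the updated tests below:

    response = await inference_impl.chat_completion(
        model_id="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hello!")],
        stream=False,
    )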
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d6fb5d662..249d3a144 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
-async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
+# TODO: this should return the registered object for all APIs
+async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
+
api = get_impl_api(p)
if obj.provider_id == "remote":
@@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
obj.provider_id = ""
if api == Api.inference:
- await p.register_model(obj)
+ return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
@@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
assert len(objects) == 1
return objects[0]
- async def register_object(self, obj: RoutableObjectWithProvider):
+ async def register_object(
+ self, obj: RoutableObjectWithProvider
+ ) -> RoutableObjectWithProvider:
# Get existing objects from registry
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
@@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
print(
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
)
- return
+ return existing_obj
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
@@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
- await register_object_with_provider(obj, p)
- await self.dist_registry.register(obj)
+ registered_obj = await register_object_with_provider(obj, p)
+ # TODO: This needs to be fixed for all APIs once they return the registered object
+ if obj.type == ResourceType.model.value:
+ await self.dist_registry.register(registered_obj)
+ return registered_obj
+
+ else:
+ await self.dist_registry.register(obj)
+ return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
@@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=metadata,
)
- await self.register_object(model)
- return model
+ registered_model = await self.register_object(model)
+ return registered_model
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
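Illustration (not part of the patch): register_model now returns the registered object, so a caller can observe the provider_resource_id the provider settled on. A sketch assuming a Models impl with these keyword names and an Ollama provider:

    model = await models_impl.register_model(
        model_id="Llama3.1-8B-Instruct",
        provider_id="ollama",
    )
    print(model.provider_resource_id)  # e.g. "llama3.1:8b-instruct-fp16"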
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index ba2fc7c95..58241eb42 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -150,7 +150,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
messages.append(candidate.system_message)
messages += input_messages
response = await self.inference_api.chat_completion(
- model=candidate.model,
+ model_id=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 2f296c7c2..38c982473 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -86,6 +86,7 @@ class Llama:
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
+ llama_model = model.core_model_id.value
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@@ -186,13 +187,20 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
- return Llama(model, tokenizer, model_args)
+ return Llama(model, tokenizer, model_args, llama_model)
- def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+ def __init__(
+ self,
+ model: Transformer,
+ tokenizer: Tokenizer,
+ args: ModelArgs,
+ llama_model: str,
+ ):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
+ self.llama_model = llama_model
@torch.inference_mode()
def generate(
@@ -369,7 +377,7 @@ class Llama:
self,
request: ChatCompletionRequest,
) -> Generator:
- messages = chat_completion_request_to_messages(request)
+ messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 2fdc8f2d5..4f5c0c8c2 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -11,9 +11,11 @@ from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
-from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import build_model_alias
+from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
@@ -28,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
-class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
+class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
+ ModelRegistryHelper.__init__(
+ self,
+ [
+ build_model_alias(
+ model.descriptor(),
+ model.core_model_id.value,
+ )
+ ],
+ )
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
@@ -45,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
else:
self.generator = Llama.build(self.config)
- async def register_model(self, model: Model) -> None:
- if model.identifier != self.model.descriptor():
- raise ValueError(
- f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
- )
-
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
@@ -68,7 +73,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -79,7 +84,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
request = CompletionRequest(
- model=model,
+ model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@@ -186,7 +191,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -201,7 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
- model=model,
+ model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -386,7 +391,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 3b1a0dd50..8869cc07f 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -110,7 +110,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -120,7 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
- model=model,
+ model=model_id,
messages=messages,
sampling_params=sampling_params,
stream=stream,
@@ -129,7 +129,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@@ -144,7 +144,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
assert self.engine is not None
request = ChatCompletionRequest(
- model=model,
+ model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -215,7 +215,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield chunk
async def embeddings(
- self, model: str, contents: list[InterleavedTextMedia]
+ self, model_id: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
index a950f35f9..4b43de93f 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
@@ -62,7 +62,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
)
judge_response = await self.inference_api.chat_completion(
- model=fn_def.params.judge_model,
+ model_id=fn_def.params.judge_model,
messages=[
{
"role": "user",
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index d9f82c611..f575d9dc3 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -7,11 +7,15 @@
from typing import * # noqa: F403
from botocore.client import BaseClient
+from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
from llama_stack.apis.inference import * # noqa: F403
@@ -19,19 +23,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-BEDROCK_SUPPORTED_MODELS = {
- "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
- "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
- "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
-}
+model_aliases = [
+ build_model_alias(
+ "meta.llama3-1-8b-instruct-v1:0",
+ CoreModelId.llama3_1_8b_instruct.value,
+ ),
+ build_model_alias(
+ "meta.llama3-1-70b-instruct-v1:0",
+ CoreModelId.llama3_1_70b_instruct.value,
+ ),
+ build_model_alias(
+ "meta.llama3-1-405b-instruct-v1:0",
+ CoreModelId.llama3_1_405b_instruct.value,
+ ),
+]
# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
- ModelRegistryHelper.__init__(
- self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
- )
+ ModelRegistryHelper.__init__(self, model_aliases)
self._config = config
self._client = create_bedrock_client(config)
@@ -49,7 +60,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -286,7 +297,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -298,8 +309,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
+ model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
- model=model,
+ model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -404,7 +416,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
pass
def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
- bedrock_model = self.map_to_provider_model(request.model)
+ bedrock_model = request.model
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
request.sampling_params
)
@@ -433,7 +445,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index f12ecb7f5..0ebb625bc 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -6,6 +6,8 @@
from typing import AsyncGenerator
+from llama_models.datatypes import CoreModelId
+
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
@@ -15,7 +17,10 @@ from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@@ -28,16 +33,23 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
-DATABRICKS_SUPPORTED_MODELS = {
- "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
- "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
-}
+model_aliases = [
+ build_model_alias(
+ "databricks-meta-llama-3-1-70b-instruct",
+ CoreModelId.llama3_1_70b_instruct.value,
+ ),
+ build_model_alias(
+ "databricks-meta-llama-3-1-405b-instruct",
+ CoreModelId.llama3_1_405b_instruct.value,
+ ),
+]
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
- self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
+ self,
+ model_aliases=model_aliases,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -113,8 +125,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
- "model": self.map_to_provider_model(request.model),
- "prompt": chat_completion_request_to_prompt(request, self.formatter),
+ "model": request.model,
+ "prompt": chat_completion_request_to_prompt(
+ request, self.get_llama_model(request.model), self.formatter
+ ),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 57e851c5b..42075eff7 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -7,14 +7,17 @@
from typing import AsyncGenerator
from fireworks.client import Fireworks
+from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
-
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@@ -31,25 +34,52 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
-FIREWORKS_SUPPORTED_MODELS = {
- "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
- "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
- "Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
- "Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
- "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
- "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
- "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
- "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
-}
+
+model_aliases = [
+ build_model_alias(
+ "fireworks/llama-v3p1-8b-instruct",
+ CoreModelId.llama3_1_8b_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-v3p1-70b-instruct",
+ CoreModelId.llama3_1_70b_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-v3p1-405b-instruct",
+ CoreModelId.llama3_1_405b_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-v3p2-1b-instruct",
+        CoreModelId.llama3_2_1b_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-v3p2-3b-instruct",
+        CoreModelId.llama3_2_3b_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-v3p2-11b-vision-instruct",
+ CoreModelId.llama3_2_11b_vision_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-v3p2-90b-vision-instruct",
+ CoreModelId.llama3_2_90b_vision_instruct.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-guard-3-8b",
+ CoreModelId.llama_guard_3_8b.value,
+ ),
+ build_model_alias(
+ "fireworks/llama-guard-3-11b-vision",
+ CoreModelId.llama_guard_3_11b_vision.value,
+ ),
+]
class FireworksInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: FireworksImplConfig) -> None:
- ModelRegistryHelper.__init__(
- self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
- )
+ ModelRegistryHelper.__init__(self, model_aliases)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -74,15 +104,16 @@ class FireworksInferenceAdapter(
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
request = CompletionRequest(
- model=model,
+ model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@@ -138,7 +169,7 @@ class FireworksInferenceAdapter(
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@@ -148,8 +179,9 @@ class FireworksInferenceAdapter(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
- model=model,
+ model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -207,7 +239,7 @@ class FireworksInferenceAdapter(
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
- request, self.formatter
+ request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
@@ -221,7 +253,7 @@ class FireworksInferenceAdapter(
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
return {
- "model": self.map_to_provider_model(request.model),
+ "model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
@@ -229,7 +261,7 @@ class FireworksInferenceAdapter(
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 938d05c08..99f74572e 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -7,15 +7,20 @@
from typing import AsyncGenerator
import httpx
+from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
-
from ollama import AsyncClient
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
+
from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@@ -33,19 +38,45 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
-OLLAMA_SUPPORTED_MODELS = {
- "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
- "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
- "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
- "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
- "Llama-Guard-3-8B": "llama-guard3:8b",
- "Llama-Guard-3-1B": "llama-guard3:1b",
- "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
-}
+
+model_aliases = [
+ build_model_alias(
+ "llama3.1:8b-instruct-fp16",
+ CoreModelId.llama3_1_8b_instruct.value,
+ ),
+ build_model_alias(
+ "llama3.1:70b-instruct-fp16",
+ CoreModelId.llama3_1_70b_instruct.value,
+ ),
+ build_model_alias(
+ "llama3.2:1b-instruct-fp16",
+ CoreModelId.llama3_2_1b_instruct.value,
+ ),
+ build_model_alias(
+ "llama3.2:3b-instruct-fp16",
+ CoreModelId.llama3_2_3b_instruct.value,
+ ),
+ build_model_alias(
+ "llama-guard3:8b",
+ CoreModelId.llama_guard_3_8b.value,
+ ),
+ build_model_alias(
+ "llama-guard3:1b",
+ CoreModelId.llama_guard_3_1b.value,
+ ),
+ build_model_alias(
+ "x/llama3.2-vision:11b-instruct-fp16",
+ CoreModelId.llama3_2_11b_vision_instruct.value,
+ ),
+]
-class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
+ ModelRegistryHelper.__init__(
+ self,
+ model_aliases=model_aliases,
+ )
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -65,44 +96,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
- async def register_model(self, model: Model) -> None:
- if model.identifier not in OLLAMA_SUPPORTED_MODELS:
- raise ValueError(f"Model {model.identifier} is not supported by Ollama")
-
- async def list_models(self) -> List[Model]:
- ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
-
- ret = []
- res = await self.client.ps()
- for r in res["models"]:
- if r["model"] not in ollama_to_llama:
- print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
- continue
-
- llama_model = ollama_to_llama[r["model"]]
- print(f"Found model {llama_model} in Ollama")
- ret.append(
- Model(
- identifier=llama_model,
- metadata={
- "ollama_model": r["model"],
- },
- )
- )
-
- return ret
-
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
request = CompletionRequest(
- model=model,
+ model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
stream=stream,
@@ -148,7 +153,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -158,8 +163,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
+ print(f"model={model}")
request = ChatCompletionRequest(
- model=model,
+ model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -197,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
input_dict["raw"] = True
input_dict["prompt"] = chat_completion_request_to_prompt(
- request, self.formatter
+ request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
@@ -207,7 +214,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
return {
- "model": OLLAMA_SUPPORTED_MODELS[request.model],
+ "model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
@@ -271,7 +278,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 28a566415..aae34bb87 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -6,6 +6,8 @@
from typing import AsyncGenerator
+from llama_models.datatypes import CoreModelId
+
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
@@ -15,7 +17,10 @@ from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@@ -33,25 +38,47 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
-TOGETHER_SUPPORTED_MODELS = {
- "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
- "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
- "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
- "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
- "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
- "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
- "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
- "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-}
+model_aliases = [
+ build_model_alias(
+ "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+ CoreModelId.llama3_1_8b_instruct.value,
+ ),
+ build_model_alias(
+ "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+ CoreModelId.llama3_1_70b_instruct.value,
+ ),
+ build_model_alias(
+ "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+ CoreModelId.llama3_1_405b_instruct.value,
+ ),
+ build_model_alias(
+ "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+ CoreModelId.llama3_2_3b_instruct.value,
+ ),
+ build_model_alias(
+ "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+ CoreModelId.llama3_2_11b_vision_instruct.value,
+ ),
+ build_model_alias(
+ "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
+ CoreModelId.llama3_2_90b_vision_instruct.value,
+ ),
+ build_model_alias(
+ "meta-llama/Meta-Llama-Guard-3-8B",
+ CoreModelId.llama_guard_3_8b.value,
+ ),
+ build_model_alias(
+ "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+ CoreModelId.llama_guard_3_11b_vision.value,
+ ),
+]
class TogetherInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: TogetherImplConfig) -> None:
- ModelRegistryHelper.__init__(
- self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
- )
+ ModelRegistryHelper.__init__(self, model_aliases)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -63,15 +90,16 @@ class TogetherInferenceAdapter(
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
request = CompletionRequest(
- model=model,
+ model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@@ -135,7 +163,7 @@ class TogetherInferenceAdapter(
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@@ -145,8 +173,9 @@ class TogetherInferenceAdapter(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
- model=model,
+ model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -204,7 +233,7 @@ class TogetherInferenceAdapter(
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
- request, self.formatter
+ request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
@@ -213,7 +242,7 @@ class TogetherInferenceAdapter(
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
return {
- "model": self.map_to_provider_model(request.model),
+ "model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
@@ -221,7 +250,7 @@ class TogetherInferenceAdapter(
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index bd7f5073c..e5eb6e1ea 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -8,13 +8,17 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.sku_list import all_registered_models, resolve_model
+from llama_models.sku_list import all_registered_models
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@@ -30,44 +34,36 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
-class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
+def build_model_aliases():
+ return [
+ build_model_alias(
+ model.huggingface_repo,
+ model.descriptor(),
+ )
+ for model in all_registered_models()
+ if model.huggingface_repo
+ ]
+
+
+class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+ ModelRegistryHelper.__init__(
+ self,
+ model_aliases=build_model_aliases(),
+ )
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
- self.huggingface_repo_to_llama_model_id = {
- model.huggingface_repo: model.descriptor()
- for model in all_registered_models()
- if model.huggingface_repo
- }
async def initialize(self) -> None:
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
- async def register_model(self, model: Model) -> None:
- for running_model in self.client.models.list():
- repo = running_model.id
- if repo not in self.huggingface_repo_to_llama_model_id:
- print(f"Unknown model served by vllm: {repo}")
- continue
-
- identifier = self.huggingface_repo_to_llama_model_id[repo]
- if identifier == model.provider_resource_id:
- print(
- f"Verified that model {model.provider_resource_id} is being served by vLLM"
- )
- return
-
- raise ValueError(
- f"Model {model.provider_resource_id} is not being served by vLLM"
- )
-
async def shutdown(self) -> None:
pass
async def completion(
self,
- model: str,
+ model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -78,7 +74,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
- model: str,
+ model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@@ -88,8 +84,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
- model=model,
+ model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -141,10 +138,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
- model = resolve_model(request.model)
- if model is None:
- raise ValueError(f"Unknown model: {request.model}")
-
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
@@ -156,16 +149,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
- request, self.formatter
+ request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
not media_present
), "Together does not support media for Completion requests"
- input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+ input_dict["prompt"] = completion_request_to_prompt(
+ request,
+ self.get_llama_model(request.model),
+ self.formatter,
+ )
return {
- "model": model.huggingface_repo,
+ "model": request.model,
**input_dict,
"stream": request.stream,
**options,
@@ -173,7 +170,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
- model: str,
+ model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index d35ebab28..f6f2a30e8 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -49,7 +49,7 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
providers=[
Provider(
provider_id=f"meta-reference-{i}",
- provider_type="meta-reference",
+ provider_type="inline::meta-reference",
config=MetaReferenceInferenceConfig(
model=m,
max_seq_len=4096,
@@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
)
+def get_model_short_name(model_name: str) -> str:
+ """Convert model name to a short test identifier.
+
+ Args:
+ model_name: Full model name like "Llama3.1-8B-Instruct"
+
+ Returns:
+ Short name like "llama_8b" suitable for test markers
+ """
+ model_name = model_name.lower()
+ if "vision" in model_name:
+ return "llama_vision"
+ elif "3b" in model_name:
+ return "llama_3b"
+ elif "8b" in model_name:
+ return "llama_8b"
+ else:
+ return model_name.replace(".", "_").replace("-", "_")
+
+
+@pytest.fixture(scope="session")
+def model_id(inference_model) -> str:
+ return get_model_short_name(inference_model)
+
+
INFERENCE_FIXTURES = [
"meta_reference",
"ollama",
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index e7bfbc135..70047a61f 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -96,7 +96,7 @@ class TestInference:
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
- model=inference_model,
+ model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),
@@ -110,7 +110,7 @@ class TestInference:
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
- model=inference_model,
+ model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),
@@ -171,7 +171,7 @@ class TestInference:
):
inference_impl, _ = inference_stack
response = await inference_impl.chat_completion(
- model=inference_model,
+ model_id=inference_model,
messages=sample_messages,
stream=False,
**common_params,
@@ -204,7 +204,7 @@ class TestInference:
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
- model=inference_model,
+ model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
@@ -227,7 +227,7 @@ class TestInference:
assert answer.num_seasons_in_nba == 15
response = await inference_impl.chat_completion(
- model=inference_model,
+ model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
@@ -250,7 +250,7 @@ class TestInference:
response = [
r
async for r in await inference_impl.chat_completion(
- model=inference_model,
+ model_id=inference_model,
messages=sample_messages,
stream=True,
**common_params,
@@ -286,7 +286,7 @@ class TestInference:
]
response = await inference_impl.chat_completion(
- model=inference_model,
+ model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
stream=False,
@@ -327,7 +327,7 @@ class TestInference:
response = [
r
async for r in await inference_impl.chat_completion(
- model=inference_model,
+ model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
stream=True,
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 141e4af31..7120e9e97 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -4,32 +4,61 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Dict
+from collections import namedtuple
+from typing import List, Optional
-from llama_models.sku_list import resolve_model
+from llama_models.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
+
+
+def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
+ for model in all_registered_models():
+ if model.descriptor() == model_descriptor:
+ return model.huggingface_repo
+ return None
+
+
+def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
+ return ModelAlias(
+ provider_model_id=provider_model_id,
+ aliases=[
+ model_descriptor,
+ get_huggingface_repo(model_descriptor),
+ ],
+ llama_model=model_descriptor,
+ )
+
class ModelRegistryHelper(ModelsProtocolPrivate):
+ def __init__(self, model_aliases: List[ModelAlias]):
+ self.alias_to_provider_id_map = {}
+ self.provider_id_to_llama_model_map = {}
+ for alias_obj in model_aliases:
+ for alias in alias_obj.aliases:
+ self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
+ # also add a mapping from provider model id to itself for easy lookup
+ self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
+ alias_obj.provider_model_id
+ )
+ self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
+ alias_obj.llama_model
+ )
- def __init__(self, stack_to_provider_models_map: Dict[str, str]):
- self.stack_to_provider_models_map = stack_to_provider_models_map
-
- def map_to_provider_model(self, identifier: str) -> str:
- model = resolve_model(identifier)
- if not model:
+ def get_provider_model_id(self, identifier: str) -> str:
+ if identifier in self.alias_to_provider_id_map:
+ return self.alias_to_provider_id_map[identifier]
+ else:
raise ValueError(f"Unknown model: `{identifier}`")
- if identifier not in self.stack_to_provider_models_map:
- raise ValueError(
- f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
- )
+ def get_llama_model(self, provider_model_id: str) -> str:
+ return self.provider_id_to_llama_model_map[provider_model_id]
- return self.stack_to_provider_models_map[identifier]
+ async def register_model(self, model: Model) -> Model:
+ model.provider_resource_id = self.get_provider_model_id(
+ model.provider_resource_id
+ )
- async def register_model(self, model: Model) -> None:
- if model.identifier not in self.stack_to_provider_models_map:
- raise ValueError(
- f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
- )
+ return model
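Illustration (not part of the patch): how the alias-based registry resolves identifiers. A minimal sketch using one of the Fireworks aliases defined earlier in this patch; the Hugging Face repo alias is looked up from the llama_models SKU list by build_model_alias:

    from llama_models.datatypes import CoreModelId
    from llama_stack.providers.utils.inference.model_registry import (
        build_model_alias,
        ModelRegistryHelper,
    )

    helper = ModelRegistryHelper(
        model_aliases=[
            build_model_alias(
                "fireworks/llama-v3p1-8b-instruct",
                CoreModelId.llama3_1_8b_instruct.value,
            )
        ]
    )
    # The Llama descriptor, its HF repo, and the provider id itself all
    # resolve to the provider model id...
    helper.get_provider_model_id("Llama3.1-8B-Instruct")
    # -> "fireworks/llama-v3p1-8b-instruct"
    # ...and the provider id maps back to the canonical Llama descriptor,
    # which prompt_adapter uses to pick the right chat template.
    helper.get_llama_model("fireworks/llama-v3p1-8b-instruct")
    # -> "Llama3.1-8B-Instruct"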
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 45e43c898..2df04664f 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content):
def chat_completion_request_to_prompt(
- request: ChatCompletionRequest, formatter: ChatFormat
+ request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> str:
- messages = chat_completion_request_to_messages(request)
+ messages = chat_completion_request_to_messages(request, llama_model)
model_input = formatter.encode_dialog_prompt(messages)
return formatter.tokenizer.decode(model_input.tokens)
def chat_completion_request_to_model_input_info(
- request: ChatCompletionRequest, formatter: ChatFormat
+ request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> Tuple[str, int]:
- messages = chat_completion_request_to_messages(request)
+ messages = chat_completion_request_to_messages(request, llama_model)
model_input = formatter.encode_dialog_prompt(messages)
return (
formatter.tokenizer.decode(model_input.tokens),
@@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info(
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
+ llama_model: str,
) -> List[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For eg. for llama_3_1, add system message with the appropriate tools or
add user messsage for custom tools, etc.
"""
- model = resolve_model(request.model)
+ model = resolve_model(llama_model)
if model is None:
- cprint(f"Could not resolve model {request.model}", color="red")
+ cprint(f"Could not resolve model {llama_model}", color="red")
return request.messages
if model.descriptor() not in supported_inference_models():