Inference to use provider resource id to register and validate (#428)

This PR changes how a model id gets translated into the final model name
that gets passed through to the provider.
Major changes include:
1) Providers are responsible for registering an object and, as part of
the registration, returning the object with the correct provider-specific
model name in provider_resource_id.
2) To help with the common lookups across the different names a model goes
by, a ModelAlias abstraction and a reworked ModelRegistryHelper class are
introduced.

Tested all inference providers including together, fireworks, vllm,
ollama, meta reference and bedrock.
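
For reference, a minimal sketch of the new lookup flow, using the helper and
alias builder from the model_registry changes below (the Fireworks ids are
taken from its adapter; the standalone setup is illustrative):

from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)

# Each provider declares aliases mapping its native model id to the
# canonical Llama descriptor (the HF repo name is added as an alias too).
helper = ModelRegistryHelper(
    [
        build_model_alias(
            "fireworks/llama-v3p1-8b-instruct",  # provider_resource_id
            "Llama3.1-8B-Instruct",  # Llama model descriptor
        )
    ]
)

# Any known alias resolves to the provider-specific id ...
assert (
    helper.get_provider_model_id("Llama3.1-8B-Instruct")
    == "fireworks/llama-v3p1-8b-instruct"
)
# ... and the provider id maps back to the Llama model, which prompt
# formatting now keys off instead of the request's model field.
assert (
    helper.get_llama_model("fireworks/llama-v3p1-8b-instruct")
    == "Llama3.1-8B-Instruct"
)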
Dinesh Yeduguru 2024-11-12 20:02:00 -08:00 committed by GitHub
parent e51107e019
commit fdff24e77a
21 changed files with 460 additions and 290 deletions

View file

@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
},
"servers": [
{
@ -2856,7 +2856,7 @@
"ChatCompletionRequest": {
"type": "object",
"properties": {
"model": {
"model_id": {
"type": "string"
},
"messages": {
@ -2993,7 +2993,7 @@
},
"additionalProperties": false,
"required": [
"model",
"model_id",
"messages"
]
},
@ -3120,7 +3120,7 @@
"CompletionRequest": {
"type": "object",
"properties": {
"model": {
"model_id": {
"type": "string"
},
"content": {
@ -3249,7 +3249,7 @@
},
"additionalProperties": false,
"required": [
"model",
"model_id",
"content"
]
},
@ -4552,7 +4552,7 @@
"EmbeddingsRequest": {
"type": "object",
"properties": {
"model": {
"model_id": {
"type": "string"
},
"contents": {
@ -4584,7 +4584,7 @@
},
"additionalProperties": false,
"required": [
"model",
"model_id",
"contents"
]
},
@ -7837,34 +7837,10 @@
],
"tags": [
{
"name": "MemoryBanks"
"name": "Safety"
},
{
"name": "BatchInference"
},
{
"name": "Agents"
},
{
"name": "Inference"
},
{
"name": "DatasetIO"
},
{
"name": "Eval"
},
{
"name": "Models"
},
{
"name": "PostTraining"
},
{
"name": "ScoringFunctions"
},
{
"name": "Datasets"
"name": "EvalTasks"
},
{
"name": "Shields"
@ -7872,15 +7848,6 @@
{
"name": "Telemetry"
},
{
"name": "Inspect"
},
{
"name": "Safety"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "Memory"
},
@ -7888,7 +7855,40 @@
"name": "Scoring"
},
{
"name": "EvalTasks"
"name": "ScoringFunctions"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "Models"
},
{
"name": "Agents"
},
{
"name": "MemoryBanks"
},
{
"name": "DatasetIO"
},
{
"name": "Inference"
},
{
"name": "Datasets"
},
{
"name": "PostTraining"
},
{
"name": "BatchInference"
},
{
"name": "Eval"
},
{
"name": "Inspect"
},
{
"name": "BuiltinTool",

View file

@ -396,7 +396,7 @@ components:
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
type: array
model:
model_id:
type: string
response_format:
oneOf:
@ -453,7 +453,7 @@ components:
$ref: '#/components/schemas/ToolDefinition'
type: array
required:
- model
- model_id
- messages
type: object
ChatCompletionResponse:
@ -577,7 +577,7 @@ components:
default: 0
type: integer
type: object
model:
model_id:
type: string
response_format:
oneOf:
@ -626,7 +626,7 @@ components:
stream:
type: boolean
required:
- model
- model_id
- content
type: object
CompletionResponse:
@ -903,10 +903,10 @@ components:
- $ref: '#/components/schemas/ImageMedia'
type: array
type: array
model:
model_id:
type: string
required:
- model
- model_id
- contents
type: object
EmbeddingsResponse:
@ -3384,7 +3384,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
\ draft and subject to change.\n Generated at 2024-11-12 15:47:15.607543"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@ -4748,24 +4748,24 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: MemoryBanks
- name: BatchInference
- name: Agents
- name: Inference
- name: DatasetIO
- name: Eval
- name: Models
- name: PostTraining
- name: ScoringFunctions
- name: Datasets
- name: Safety
- name: EvalTasks
- name: Shields
- name: Telemetry
- name: Inspect
- name: Safety
- name: SyntheticDataGeneration
- name: Memory
- name: Scoring
- name: EvalTasks
- name: ScoringFunctions
- name: SyntheticDataGeneration
- name: Models
- name: Agents
- name: MemoryBanks
- name: DatasetIO
- name: Inference
- name: Datasets
- name: PostTraining
- name: BatchInference
- name: Eval
- name: Inspect
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"

View file

@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working c
$ curl http://localhost:5000/inference/chat_completion \
-H "Content-Type: application/json" \
-d '{
"model": "Llama3.1-8B-Instruct",
"model_id": "Llama3.1-8B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a 2 sentence poem about the moon"}

View file

@ -226,7 +226,7 @@ class Inference(Protocol):
@webmethod(route="/inference/completion")
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -237,7 +237,7 @@ class Inference(Protocol):
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
@ -254,6 +254,6 @@ class Inference(Protocol):
@webmethod(route="/inference/embeddings")
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ...

View file

@ -95,7 +95,7 @@ class InferenceRouter(Inference):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -106,7 +106,7 @@ class InferenceRouter(Inference):
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
params = dict(
model=model,
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -116,7 +116,7 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
provider = self.routing_table.get_provider_impl(model)
provider = self.routing_table.get_provider_impl(model_id)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
else:
@ -124,16 +124,16 @@ class InferenceRouter(Inference):
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
provider = self.routing_table.get_provider_impl(model)
provider = self.routing_table.get_provider_impl(model_id)
params = dict(
model=model,
model_id=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -147,11 +147,11 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
return await self.routing_table.get_provider_impl(model).embeddings(
model=model,
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
)

View file

@ -28,7 +28,9 @@ def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
if obj.provider_id == "remote":
@ -42,7 +44,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
obj.provider_id = ""
if api == Api.inference:
await p.register_model(obj)
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
elif api == Api.memory:
@ -167,7 +169,9 @@ class CommonRoutingTableImpl(RoutingTable):
assert len(objects) == 1
return objects[0]
async def register_object(self, obj: RoutableObjectWithProvider):
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# Get existing objects from registry
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
@ -177,7 +181,7 @@ class CommonRoutingTableImpl(RoutingTable):
print(
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
)
return
return existing_obj
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
@ -188,8 +192,15 @@ class CommonRoutingTableImpl(RoutingTable):
p = self.impls_by_provider_id[obj.provider_id]
await register_object_with_provider(obj, p)
await self.dist_registry.register(obj)
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
@ -228,8 +239,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=metadata,
)
await self.register_object(model)
return model
registered_model = await self.register_object(model)
return registered_model
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

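A short caller-side sketch of the new contract (routing_table and model here
are assumed, pre-built objects; only the Inference API returns the registered
object so far, per the TODO above):

# (inside an async context)
registered = await routing_table.register_object(model)
# register_object -- and therefore ModelsRoutingTable.register_model -- now
# returns the registered object instead of None. For models, the registry
# persists the provider-returned copy, so registered.provider_resource_id
# holds the provider's native model name.
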
View file

@ -150,7 +150,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
messages.append(candidate.system_message)
messages += input_messages
response = await self.inference_api.chat_completion(
model=candidate.model,
model_id=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
)

View file

@ -86,6 +86,7 @@ class Llama:
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
llama_model = model.core_model_id.value
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
@ -186,13 +187,20 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
return Llama(model, tokenizer, model_args, llama_model)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(
self,
model: Transformer,
tokenizer: Tokenizer,
args: ModelArgs,
llama_model: str,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
self.llama_model = llama_model
@torch.inference_mode()
def generate(
@ -369,7 +377,7 @@ class Llama:
self,
request: ChatCompletionRequest,
) -> Generator:
messages = chat_completion_request_to_messages(request)
messages = chat_completion_request_to_messages(request, self.llama_model)
sampling_params = request.sampling_params
max_gen_len = sampling_params.max_tokens

View file

@ -11,9 +11,11 @@ from typing import AsyncGenerator, List
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
@ -28,10 +30,19 @@ from .model_parallel import LlamaModelParallelGenerator
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
ModelRegistryHelper.__init__(
self,
[
build_model_alias(
model.descriptor(),
model.core_model_id.value,
)
],
)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
@ -45,12 +56,6 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
else:
self.generator = Llama.build(self.config)
async def register_model(self, model: Model) -> None:
if model.identifier != self.model.descriptor():
raise ValueError(
f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
)
async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
@ -68,7 +73,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -79,7 +84,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
request = CompletionRequest(
model=model,
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -186,7 +191,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -201,7 +206,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -386,7 +391,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -110,7 +110,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -120,7 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("vLLM completion")
messages = [UserMessage(content=content)]
return self.chat_completion(
model=model,
model=model_id,
messages=messages,
sampling_params=sampling_params,
stream=stream,
@ -129,7 +129,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@ -144,7 +144,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
assert self.engine is not None
request = ChatCompletionRequest(
model=model,
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -215,7 +215,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
yield chunk
async def embeddings(
self, model: str, contents: list[InterleavedTextMedia]
self, model_id: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse:
log.info("vLLM embeddings")
# TODO

View file

@ -62,7 +62,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
)
judge_response = await self.inference_api.chat_completion(
model=fn_def.params.judge_model,
model_id=fn_def.params.judge_model,
messages=[
{
"role": "user",

View file

@ -7,11 +7,15 @@
from typing import * # noqa: F403
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.apis.inference import * # noqa: F403
@ -19,19 +23,26 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
BEDROCK_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}
model_aliases = [
build_model_alias(
"meta.llama3-1-8b-instruct-v1:0",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta.llama3-1-70b-instruct-v1:0",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value,
),
]
# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=BEDROCK_SUPPORTED_MODELS
)
ModelRegistryHelper.__init__(self, model_aliases)
self._config = config
self._client = create_bedrock_client(config)
@ -49,7 +60,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -286,7 +297,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -298,8 +309,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -404,7 +416,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
pass
def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
bedrock_model = self.map_to_provider_model(request.model)
bedrock_model = request.model
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
request.sampling_params
)
@ -433,7 +445,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -6,6 +6,8 @@
from typing import AsyncGenerator
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
@ -15,7 +17,10 @@ from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@ -28,16 +33,23 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
model_aliases = [
build_model_alias(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
]
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
self,
model_aliases=model_aliases,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@ -113,8 +125,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"model": request.model,
"prompt": chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}

View file

@ -7,14 +7,17 @@
from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@ -31,25 +34,52 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
"Llama3.2-1B-Instruct": "fireworks/llama-v3p2-1b-instruct",
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
}
model_aliases = [
build_model_alias(
"fireworks/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"fireworks/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"fireworks/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
class FireworksInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
)
ModelRegistryHelper.__init__(self, model_aliases)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@ -74,15 +104,16 @@ class FireworksInferenceAdapter(
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -138,7 +169,7 @@ class FireworksInferenceAdapter(
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@ -148,8 +179,9 @@ class FireworksInferenceAdapter(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -207,7 +239,7 @@ class FireworksInferenceAdapter(
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
@ -221,7 +253,7 @@ class FireworksInferenceAdapter(
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
return {
"model": self.map_to_provider_model(request.model),
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
@ -229,7 +261,7 @@ class FireworksInferenceAdapter(
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -7,15 +7,20 @@
from typing import AsyncGenerator
import httpx
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -33,19 +38,45 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
request_has_media,
)
OLLAMA_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
"Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "llama-guard3:8b",
"Llama-Guard-3-1B": "llama-guard3:1b",
"Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
}
model_aliases = [
build_model_alias(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
build_model_alias(
"x/llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
]
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, url: str) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
)
self.url = url
self.formatter = ChatFormat(Tokenizer.get_instance())
@ -65,44 +96,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> None:
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
raise ValueError(f"Model {model.identifier} is not supported by Ollama")
async def list_models(self) -> List[Model]:
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
ret = []
res = await self.client.ps()
for r in res["models"]:
if r["model"] not in ollama_to_llama:
print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
continue
llama_model = ollama_to_llama[r["model"]]
print(f"Found model {llama_model} in Ollama")
ret.append(
Model(
identifier=llama_model,
metadata={
"ollama_model": r["model"],
},
)
)
return ret
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
stream=stream,
@ -148,7 +153,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -158,8 +163,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
print(f"model={model}")
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -197,7 +204,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
input_dict["raw"] = True
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
@ -207,7 +214,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
input_dict["raw"] = True
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
@ -271,7 +278,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -6,6 +6,8 @@
from typing import AsyncGenerator
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
@ -15,7 +17,10 @@ from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@ -33,25 +38,47 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
model_aliases = [
build_model_alias(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
class TogetherInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
)
ModelRegistryHelper.__init__(self, model_aliases)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
@ -63,15 +90,16 @@ class TogetherInferenceAdapter(
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@ -135,7 +163,7 @@ class TogetherInferenceAdapter(
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
@ -145,8 +173,9 @@ class TogetherInferenceAdapter(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -204,7 +233,7 @@ class TogetherInferenceAdapter(
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
@ -213,7 +242,7 @@ class TogetherInferenceAdapter(
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
return {
"model": self.map_to_provider_model(request.model),
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format),
@ -221,7 +250,7 @@ class TogetherInferenceAdapter(
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -8,13 +8,17 @@ from typing import AsyncGenerator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models, resolve_model
from llama_models.sku_list import all_registered_models
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@ -30,44 +34,36 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def build_model_aliases():
return [
build_model_alias(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class VLLMInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=build_model_aliases(),
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = None
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}
async def initialize(self) -> None:
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def register_model(self, model: Model) -> None:
for running_model in self.client.models.list():
repo = running_model.id
if repo not in self.huggingface_repo_to_llama_model_id:
print(f"Unknown model served by vllm: {repo}")
continue
identifier = self.huggingface_repo_to_llama_model_id[repo]
if identifier == model.provider_resource_id:
print(
f"Verified that model {model.provider_resource_id} is being served by vLLM"
)
return
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM"
)
async def shutdown(self) -> None:
pass
async def completion(
self,
model: str,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -78,7 +74,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def chat_completion(
self,
model: str,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
@ -88,8 +84,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@ -141,10 +138,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
model = resolve_model(request.model)
if model is None:
raise ValueError(f"Unknown model: {request.model}")
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
@ -156,16 +149,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
]
else:
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
request, self.get_llama_model(request.model), self.formatter
)
else:
assert (
not media_present
), "Together does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
input_dict["prompt"] = completion_request_to_prompt(
request,
self.get_llama_model(request.model),
self.formatter,
)
return {
"model": model.huggingface_repo,
"model": request.model,
**input_dict,
"stream": request.stream,
**options,
@ -173,7 +170,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model: str,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -49,7 +49,7 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
providers=[
Provider(
provider_id=f"meta-reference-{i}",
provider_type="meta-reference",
provider_type="inline::meta-reference",
config=MetaReferenceInferenceConfig(
model=m,
max_seq_len=4096,
@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
)
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
Args:
model_name: Full model name like "Llama3.1-8B-Instruct"
Returns:
Short name like "llama_8b" suitable for test markers
"""
model_name = model_name.lower()
if "vision" in model_name:
return "llama_vision"
elif "3b" in model_name:
return "llama_3b"
elif "8b" in model_name:
return "llama_8b"
else:
return model_name.replace(".", "_").replace("-", "_")
@pytest.fixture(scope="session")
def model_id(inference_model) -> str:
return get_model_short_name(inference_model)
INFERENCE_FIXTURES = [
"meta_reference",
"ollama",

View file

@ -96,7 +96,7 @@ class TestInference:
response = await inference_impl.completion(
content="Micheael Jordan is born in ",
stream=False,
model=inference_model,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),
@ -110,7 +110,7 @@ class TestInference:
async for r in await inference_impl.completion(
content="Roses are red,",
stream=True,
model=inference_model,
model_id=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),
@ -171,7 +171,7 @@ class TestInference:
):
inference_impl, _ = inference_stack
response = await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=sample_messages,
stream=False,
**common_params,
@ -204,7 +204,7 @@ class TestInference:
num_seasons_in_nba: int
response = await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
@ -227,7 +227,7 @@ class TestInference:
assert answer.num_seasons_in_nba == 15
response = await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Please give me information about Michael Jordan."),
@ -250,7 +250,7 @@ class TestInference:
response = [
r
async for r in await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=sample_messages,
stream=True,
**common_params,
@ -286,7 +286,7 @@ class TestInference:
]
response = await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
stream=False,
@ -327,7 +327,7 @@ class TestInference:
response = [
r
async for r in await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=messages,
tools=[sample_tool_definition],
stream=True,

View file

@ -4,32 +4,61 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from collections import namedtuple
from typing import List, Optional
from llama_models.sku_list import resolve_model
from llama_models.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_model"])
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
for model in all_registered_models():
if model.descriptor() == model_descriptor:
return model.huggingface_repo
return None
def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
return ModelAlias(
provider_model_id=provider_model_id,
aliases=[
model_descriptor,
get_huggingface_repo(model_descriptor),
],
llama_model=model_descriptor,
)
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_aliases: List[ModelAlias]):
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for alias_obj in model_aliases:
for alias in alias_obj.aliases:
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
# also add a mapping from provider model id to itself for easy lookup
self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
alias_obj.provider_model_id
)
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
alias_obj.llama_model
)
def __init__(self, stack_to_provider_models_map: Dict[str, str]):
self.stack_to_provider_models_map = stack_to_provider_models_map
def map_to_provider_model(self, identifier: str) -> str:
model = resolve_model(identifier)
if not model:
def get_provider_model_id(self, identifier: str) -> str:
if identifier in self.alias_to_provider_id_map:
return self.alias_to_provider_id_map[identifier]
else:
raise ValueError(f"Unknown model: `{identifier}`")
if identifier not in self.stack_to_provider_models_map:
raise ValueError(
f"Model {identifier} not found in map {self.stack_to_provider_models_map}"
)
def get_llama_model(self, provider_model_id: str) -> str:
return self.provider_id_to_llama_model_map[provider_model_id]
return self.stack_to_provider_models_map[identifier]
async def register_model(self, model: Model) -> Model:
model.provider_resource_id = self.get_provider_model_id(
model.provider_resource_id
)
async def register_model(self, model: Model) -> None:
if model.identifier not in self.stack_to_provider_models_map:
raise ValueError(
f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
)
return model
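
A hedged end-to-end sketch of register_model (the Model import path and field
names are assumptions for illustration):

import asyncio

from llama_stack.apis.models import Model
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)

async def main() -> None:
    helper = ModelRegistryHelper(
        [build_model_alias("llama3.1:8b-instruct-fp16", "Llama3.1-8B-Instruct")]
    )
    # The routing table hands over the object with provider_resource_id set
    # to whatever the user registered; the helper swaps in the provider id,
    # or raises ValueError for an unknown alias.
    model = Model(
        identifier="Llama3.1-8B-Instruct",
        provider_resource_id="Llama3.1-8B-Instruct",
        provider_id="ollama",
    )
    model = await helper.register_model(model)
    assert model.provider_resource_id == "llama3.1:8b-instruct-fp16"

asyncio.run(main())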

View file

@ -147,17 +147,17 @@ def augment_content_with_response_format_prompt(response_format, content):
def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> str:
messages = chat_completion_request_to_messages(request)
messages = chat_completion_request_to_messages(request, llama_model)
model_input = formatter.encode_dialog_prompt(messages)
return formatter.tokenizer.decode(model_input.tokens)
def chat_completion_request_to_model_input_info(
request: ChatCompletionRequest, formatter: ChatFormat
request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
) -> Tuple[str, int]:
messages = chat_completion_request_to_messages(request)
messages = chat_completion_request_to_messages(request, llama_model)
model_input = formatter.encode_dialog_prompt(messages)
return (
formatter.tokenizer.decode(model_input.tokens),
@ -167,14 +167,15 @@ def chat_completion_request_to_model_input_info(
def chat_completion_request_to_messages(
request: ChatCompletionRequest,
llama_model: str,
) -> List[Message]:
"""Reads chat completion request and augments the messages to handle tools.
For example, for llama_3_1, add a system message with the appropriate tools, or
add a user message for custom tools, etc.
"""
model = resolve_model(request.model)
model = resolve_model(llama_model)
if model is None:
cprint(f"Could not resolve model {request.model}", color="red")
cprint(f"Could not resolve model {llama_model}", color="red")
return request.messages
if model.descriptor() not in supported_inference_models():