Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
This PR enables routing of fully qualified model IDs of the form `provider_id/model_id` even when the models are not registered with the Stack. Here's the situation: assume a remote inference provider that works only when users provide their own API keys via the `X-LlamaStack-Provider-Data` header. By definition, we cannot list its models and hence cannot update our routing registry. But because we now _require_ a provider ID in the model ID, we can identify which provider to route to and let that provider decide. Note that we still try to look up our registry first, since it may hold a pre-registered alias; we just don't outright fail when the lookup comes up empty.

Also updated the inference router so that responses carry the _exact_ model ID that the request used.

## Test Plan

Added an integration test.

Closes #3929

---

This is an automatic backport of pull request #3928 done by [Mergify](https://mergify.com).

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
Co-authored-by: ehhuang <ehhuang@users.noreply.github.com>
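For illustration only (not part of this PR), here is a minimal client-side sketch of the flow described above. It assumes a locally running Stack server whose OpenAI-compatible API is reachable at the base URL shown, a provider registered with id `together` that takes its key from provider data, and the `openai` Python client; the port, endpoint path, provider id, provider-data key field, and model name are all assumptions — only the header name comes from this PR's description.

```python
# Hypothetical usage sketch: call a model that was never registered with the
# Stack, relying on the new provider_id/model_id routing. The base URL, provider
# id, provider-data key field, and model name below are illustrative.
import json

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8321/v1",  # assumed OpenAI-compatible endpoint of the Stack server
    api_key="none",  # the Stack server itself is not being given a key here
    default_headers={
        # Per-request provider credentials, passed via the header this PR relies on
        "X-LlamaStack-Provider-Data": json.dumps({"together_api_key": "<YOUR_KEY>"}),
    },
)

# Fully qualified model ID of the form "<provider_id>/<model_id>". The model is
# not registered with the Stack; the router first checks its registry for an
# alias and otherwise parses the provider prefix and routes to that provider.
response = client.chat.completions.create(
    model="together/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    messages=[{"role": "user", "content": "Hello!"}],
)

# With this PR, the response echoes the exact model ID the request used.
print(response.model)
```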
Parent commit: a6c3a9cadf
Commit: 641d5144be

6 changed files with 214 additions and 55 deletions
```diff
@@ -105,7 +105,8 @@ class InferenceRouter(Inference):
         prompt_tokens: int,
         completion_tokens: int,
         total_tokens: int,
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
     ) -> list[MetricEvent]:
         """Constructs a list of MetricEvent objects containing token usage metrics.
```
```diff
@@ -113,7 +114,8 @@ class InferenceRouter(Inference):
             prompt_tokens: Number of tokens in the prompt
             completion_tokens: Number of tokens in the completion
             total_tokens: Total number of tokens used
-            model: Model object containing model_id and provider_id
+            fully_qualified_model_id:
+            provider_id: The provider identifier

         Returns:
             List of MetricEvent objects with token usage metrics
```
```diff
@@ -139,8 +141,8 @@ class InferenceRouter(Inference):
                     timestamp=datetime.now(UTC),
                     unit="tokens",
                     attributes={
-                        "model_id": model.model_id,
-                        "provider_id": model.provider_id,
+                        "model_id": fully_qualified_model_id,
+                        "provider_id": provider_id,
                     },
                 )
             )
```
```diff
@@ -153,7 +155,9 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
+        )
         if self.telemetry:
             for metric in metrics:
                 enqueue_event(metric)
```
```diff
@@ -173,14 +177,25 @@ class InferenceRouter(Inference):
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0

-    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
-        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type != expected_model_type:
-            raise ModelTypeError(model_id, model.model_type, expected_model_type)
-        return model
+    async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
+        model = await self.routing_table.get_object_by_identifier("model", model_id)
+        if model:
+            if model.model_type != expected_model_type:
+                raise ModelTypeError(model_id, model.model_type, expected_model_type)
+
+            provider = await self.routing_table.get_provider_impl(model.identifier)
+            return provider, model.provider_resource_id
+
+        splits = model_id.split("/", maxsplit=1)
+        if len(splits) != 2:
+            raise ModelNotFoundError(model_id)
+
+        provider_id, provider_resource_id = splits
+        if provider_id not in self.routing_table.impls_by_provider_id:
+            logger.warning(f"Provider {provider_id} not found for model {model_id}")
+            raise ModelNotFoundError(model_id)
+
+        return self.routing_table.impls_by_provider_id[provider_id], provider_resource_id

     async def openai_completion(
         self,
```
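One small detail from the hunk above worth calling out: the split uses `maxsplit=1` so routing still works when the provider-side model ID itself contains slashes. A quick illustrative check (the model name is made up):

```python
# Illustrative only: only the first "/" separates the provider prefix from the
# provider-side model ID, which may itself contain slashes.
fq_model_id = "together/meta-llama/Llama-3.3-70B-Instruct-Turbo"
provider_id, provider_resource_id = fq_model_id.split("/", maxsplit=1)
assert provider_id == "together"
assert provider_resource_id == "meta-llama/Llama-3.3-70B-Instruct-Turbo"
```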
```diff
@@ -189,24 +204,24 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
-
-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id
+
         if params.stream:
             return await provider.openai_completion(params)
             # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
             # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.

         response = await provider.openai_completion(params)
+        response.model = request_model_id
         if self.telemetry:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
```
```diff
@@ -224,7 +239,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(params.model, ModelType.llm)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.llm)
+        params.model = provider_resource_id

         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
```
```diff
@@ -242,10 +259,6 @@ class InferenceRouter(Inference):
             params.tool_choice = None
             params.tools = None

-        # Update params with the resolved model identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
         if params.stream:
             response_stream = await provider.openai_chat_completion(params)

```
```diff
@@ -253,11 +266,13 @@ class InferenceRouter(Inference):
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
                 messages=params.messages,
             )

         response = await self._nonstream_openai_chat_completion(provider, params)
+        response.model = request_model_id

         # Store the response with the ID that will be returned to the client
         if self.store:
```
```diff
@@ -268,7 +283,8 @@ class InferenceRouter(Inference):
                 prompt_tokens=response.usage.prompt_tokens,
                 completion_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
-                model=model_obj,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
             )
             for metric in metrics:
                 enqueue_event(metric)
```
```diff
@@ -285,13 +301,13 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_embeddings: model={params.model}, input_type={type(params.input)}, encoding_format={params.encoding_format}, dimensions={params.dimensions}",
         )
-        model_obj = await self._get_model(params.model, ModelType.embedding)
-
-        # Update model to use resolved identifier
-        params.model = model_obj.identifier
-
-        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        return await provider.openai_embeddings(params)
+        request_model_id = params.model
+        provider, provider_resource_id = await self._get_model_provider(params.model, ModelType.embedding)
+        params.model = provider_resource_id
+
+        response = await provider.openai_embeddings(params)
+        response.model = request_model_id
+        return response

     async def list_chat_completions(
         self,
```
```diff
@@ -347,7 +363,8 @@ class InferenceRouter(Inference):
         self,
         response,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
         completion_text = ""
```
```diff
@@ -385,7 +402,8 @@ class InferenceRouter(Inference):
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
-                model=model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             for metric in completion_metrics:
                 if metric.metric in [
```
```diff
@@ -405,7 +423,8 @@ class InferenceRouter(Inference):
                 prompt_tokens or 0,
                 completion_tokens or 0,
                 total_tokens,
-                model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             async_metrics = [
                 MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
```
```diff
@@ -417,7 +436,8 @@ class InferenceRouter(Inference):
         self,
         response: ChatCompletionResponse | CompletionResponse,
         prompt_tokens,
-        model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         tool_prompt_format: ToolPromptFormat | None = None,
     ):
         if isinstance(response, ChatCompletionResponse):
```
```diff
@@ -434,7 +454,8 @@ class InferenceRouter(Inference):
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=total_tokens,
-                model=model,
+                fully_qualified_model_id=fully_qualified_model_id,
+                provider_id=provider_id,
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
```
```diff
@@ -448,14 +469,16 @@ class InferenceRouter(Inference):
             prompt_tokens or 0,
             completion_tokens or 0,
             total_tokens,
-            model,
+            fully_qualified_model_id=fully_qualified_model_id,
+            provider_id=provider_id,
         )
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
         response: AsyncIterator[OpenAIChatCompletionChunk],
-        model: Model,
+        fully_qualified_model_id: str,
+        provider_id: str,
         messages: list[OpenAIMessageParam] | None = None,
     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
         """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
```
```diff
@@ -475,6 +498,8 @@ class InferenceRouter(Inference):
             if created is None and chunk.created:
                 created = chunk.created

+            chunk.model = fully_qualified_model_id
+
             # Accumulate choice data for final assembly
             if chunk.choices:
                 for choice_delta in chunk.choices:
```
```diff
@@ -531,7 +556,8 @@ class InferenceRouter(Inference):
                         prompt_tokens=chunk.usage.prompt_tokens,
                         completion_tokens=chunk.usage.completion_tokens,
                         total_tokens=chunk.usage.total_tokens,
-                        model=model,
+                        model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
                     )
                     for metric in metrics:
                         enqueue_event(metric)
```
```diff
@@ -579,7 +605,7 @@ class InferenceRouter(Inference):
             id=id,
             choices=assembled_choices,
             created=created or int(time.time()),
-            model=model.identifier,
+            model=fully_qualified_model_id,
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
```