Omar Abdelwahab 2025-10-03 01:40:16 +00:00 committed by GitHub
commit 656addb45d
6 changed files with 302 additions and 137 deletions

View file

@@ -6,16 +6,7 @@
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    Literal,
-    Protocol,
-    runtime_checkable,
-)
-
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import TypedDict
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.common.responses import Order
@@ -32,6 +23,9 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import TypedDict

 register_schema(ToolCall)
 register_schema(ToolDefinition)
@@ -357,32 +351,32 @@ class CompletionRequest(BaseModel):
     logprobs: LogProbConfig | None = None

-@json_schema_type
-class CompletionResponse(MetricResponseMixin):
-    """Response from a completion request.
-    :param content: The generated completion text
-    :param stop_reason: Reason why generation stopped
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-    content: str
-    stop_reason: StopReason
-    logprobs: list[TokenLogProbs] | None = None
-
-@json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed completion response.
-    :param delta: New content generated since last chunk. This can be one or more tokens.
-    :param stop_reason: Optional reason why generation stopped, if complete
-    :param logprobs: Optional log probabilities for generated tokens
-    """
-    delta: str
-    stop_reason: StopReason | None = None
-    logprobs: list[TokenLogProbs] | None = None
+# @json_schema_type
+# class CompletionResponse(MetricResponseMixin):
+#     """Response from a completion request.
+#     :param content: The generated completion text
+#     :param stop_reason: Reason why generation stopped
+#     :param logprobs: Optional log probabilities for generated tokens
+#     """
+#     content: str
+#     stop_reason: StopReason
+#     logprobs: list[TokenLogProbs] | None = None
+
+# @json_schema_type
+# class CompletionResponseStreamChunk(MetricResponseMixin):
+#     """A chunk of a streamed completion response.
+#     :param delta: New content generated since last chunk. This can be one or more tokens.
+#     :param stop_reason: Optional reason why generation stopped, if complete
+#     :param logprobs: Optional log probabilities for generated tokens
+#     """
+#     delta: str
+#     stop_reason: StopReason | None = None
+#     logprobs: list[TokenLogProbs] | None = None


 class SystemMessageBehavior(Enum):
@@ -415,7 +409,9 @@ class ToolConfig(BaseModel):
     tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
     tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
+    system_message_behavior: SystemMessageBehavior | None = Field(
+        default=SystemMessageBehavior.append
+    )

     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -544,15 +540,21 @@ class OpenAIFile(BaseModel):
 OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
+    OpenAIChatCompletionContentPartTextParam
+    | OpenAIChatCompletionContentPartImageParam
+    | OpenAIFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
+register_schema(
+    OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
+)

 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
-OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+OpenAIChatCompletionTextOnlyMessageContent = (
+    str | list[OpenAIChatCompletionContentPartTextParam]
+)

 @json_schema_type
@@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel):
 OpenAIResponseFormatParam = Annotated[
-    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatText
+    | OpenAIResponseFormatJSONSchema
+    | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -1049,8 +1053,16 @@ class InferenceProvider(Protocol):
     async def rerank(
         self,
         model: str,
-        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
-        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        query: (
+            str
+            | OpenAIChatCompletionContentPartTextParam
+            | OpenAIChatCompletionContentPartImageParam
+        ),
+        items: list[
+            str
+            | OpenAIChatCompletionContentPartTextParam
+            | OpenAIChatCompletionContentPartImageParam
+        ],
         max_num_results: int | None = None,
     ) -> RerankResponse:
         """Rerank a list of documents based on their relevance to a query.
@@ -1064,7 +1076,12 @@ class InferenceProvider(Protocol):
         raise NotImplementedError("Reranking is not implemented")
         return  # this is so mypy's safe-super rule will consider the method concrete

-    @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
@@ -1116,7 +1133,12 @@ class InferenceProvider(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
@@ -1173,7 +1195,12 @@ class InferenceProvider(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/embeddings",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_embeddings(
         self,
@@ -1203,7 +1230,12 @@ class Inference(InferenceProvider):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """

-    @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
     async def list_chat_completions(
         self,
@@ -1223,10 +1255,19 @@ class Inference(InferenceProvider):
         raise NotImplementedError("List chat completions is not implemented")

     @webmethod(
-        route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
+        route="/openai/v1/chat/completions/{completion_id}",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
     )
-    @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    @webmethod(
+        route="/chat/completions/{completion_id}",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+    )
+    async def get_chat_completion(
+        self, completion_id: str
+    ) -> OpenAICompletionWithInputMessages:
         """Describe a chat completion by its ID.

         :param completion_id: ID of the chat completion.
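
Several hunks in this file only re-wrap the stacked @webmethod decorators that register each endpoint twice: once under the deprecated /openai/v1/... prefix and once under the new route. The sketch below is a minimal, self-contained illustration of that stacking pattern; the webmethod decorator and RouteInfo type here are hypothetical stand-ins, not the llama_stack implementations.

# Minimal sketch (assumed names): a stand-in "webmethod" decorator that records
# every route declared on a handler, so one function can serve both a deprecated
# alias and its replacement. Illustrative only, not llama_stack's decorator.
from dataclasses import dataclass


@dataclass
class RouteInfo:
    route: str
    method: str
    deprecated: bool = False


def webmethod(route: str, method: str = "POST", deprecated: bool = False):
    def wrap(func):
        routes = getattr(func, "__routes__", [])
        routes.append(RouteInfo(route=route, method=method, deprecated=deprecated))
        func.__routes__ = routes
        return func

    return wrap


@webmethod(route="/openai/v1/chat/completions", method="POST", deprecated=True)
@webmethod(route="/chat/completions", method="POST")
async def openai_chat_completion(model: str, messages: list[dict]) -> dict:
    return {"model": model, "choices": []}


if __name__ == "__main__":
    # Decorators apply bottom-up, so the new route is recorded first.
    for info in openai_chat_completion.__routes__:
        flag = " (deprecated)" if info.deprecated else ""
        print(f"{info.method} {info.route}{flag}")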

View file

@@ -7,24 +7,16 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
-from datetime import UTC, datetime
+from datetime import datetime, UTC
 from typing import Annotated, Any

-from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
-from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
-
-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-)
+from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
     Inference,
     ListOpenAIChatCompletionResponse,
     LogProbConfig,
@@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
+from llama_stack.providers.utils.telemetry.tracing import (
+    enqueue_event,
+    get_current_span,
+)
+from openai.types.chat import (
+    ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam as OpenAIChatCompletionToolParam,
+)
+from pydantic import Field, TypeAdapter

 logger = get_logger(name=__name__, category="core::routers")
@@ -101,7 +102,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )

     def _construct_metrics(
         self,
@@ -156,11 +159,16 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 enqueue_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]

     async def _count_tokens(
         self,
@@ -207,8 +215,13 @@ class InferenceRouter(Inference):
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -226,9 +239,14 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+                tool_names = [
+                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                    for t in tools
+                ]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                    raise ValueError(
+                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                    )

         params = dict(
             model_id=model_id,
@@ -243,7 +261,9 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = await self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )

         if stream:
             response_stream = await provider.chat_completion(**params)
@@ -263,7 +283,9 @@ class InferenceRouter(Inference):
         )
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
@@ -336,7 +358,9 @@ class InferenceRouter(Inference):
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
@@ -374,9 +398,13 @@ class InferenceRouter(Inference):
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
         if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
+                tool_choice
+            )
             if tools is None:
-                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
+                raise ValueError(
+                    "'tool_choice' is only allowed when 'tools' is also provided"
+                )
         if tools:
             for tool in tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
@@ -441,7 +469,9 @@ class InferenceRouter(Inference):
                 enqueue_event(metric)
             # these metrics will show up in the client response.
             response.metrics = (
-                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+                metrics
+                if not hasattr(response, "metrics") or response.metrics is None
+                else response.metrics + metrics
             )
         return response
@@ -477,19 +507,31 @@ class InferenceRouter(Inference):
     ) -> ListOpenAIChatCompletionResponse:
         if self.store:
             return await self.store.list_chat_completions(after, limit, model, order)
-        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+        raise NotImplementedError(
+            "List chat completions is not supported: inference store is not configured."
+        )

-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    async def get_chat_completion(
+        self, completion_id: str
+    ) -> OpenAICompletionWithInputMessages:
         if self.store:
             return await self.store.get_chat_completion(completion_id)
-        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
+        raise NotImplementedError(
+            "Get chat completion is not supported: inference store is not configured."
+        )

-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: dict
+    ) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty
-            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
+            if (
+                choice.message
+                and choice.message.tool_calls is not None
+                and len(choice.message.tool_calls) == 0
+            ):
                 choice.message.tool_calls = None
         return response
@@ -509,7 +551,9 @@ class InferenceRouter(Inference):
                     message=f"Health check timed out after {timeout} seconds",
                 )
             except NotImplementedError:
-                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.NOT_IMPLEMENTED
+                )
             except Exception as e:
                 health_statuses[provider_id] = HealthResponse(
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
@@ -522,7 +566,7 @@ class InferenceRouter(Inference):
         prompt_tokens,
         model,
         tool_prompt_format: ToolPromptFormat | None = None,
-    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         completion_text = ""
         async for chunk in response:
             complete = False
@@ -544,7 +588,11 @@ class InferenceRouter(Inference):
             else:
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                if (
+                    hasattr(chunk, "stop_reason")
+                    and chunk.stop_reason
+                    and self.telemetry
+                ):
                     complete = True
                     completion_tokens = await self._count_tokens(completion_text)
             # if we are done receiving tokens
@@ -569,9 +617,14 @@ class InferenceRouter(Inference):
                     # Return metrics in response
                     async_metrics = [
-                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                        MetricInResponse(metric=metric.metric, value=metric.value)
+                        for metric in completion_metrics
                     ]
-                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                    chunk.metrics = (
+                        async_metrics
+                        if chunk.metrics is None
+                        else chunk.metrics + async_metrics
+                    )
                 else:
                     # Fallback if no telemetry
                     completion_metrics = self._construct_metrics(
@@ -581,14 +634,19 @@ class InferenceRouter(Inference):
                         model,
                     )
                     async_metrics = [
-                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                        MetricInResponse(metric=metric.metric, value=metric.value)
+                        for metric in completion_metrics
                     ]
-                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                    chunk.metrics = (
+                        async_metrics
+                        if chunk.metrics is None
+                        else chunk.metrics + async_metrics
+                    )
             yield chunk

     async def count_tokens_and_compute_metrics(
         self,
-        response: ChatCompletionResponse | CompletionResponse,
+        response: ChatCompletionResponse,
         prompt_tokens,
         model,
         tool_prompt_format: ToolPromptFormat | None = None,
@@ -597,7 +655,9 @@ class InferenceRouter(Inference):
             content = [response.completion_message]
         else:
             content = response.content
-        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
+        completion_tokens = await self._count_tokens(
+            messages=content, tool_prompt_format=tool_prompt_format
+        )
         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

         # Create a separate span for completion metrics
@@ -610,11 +670,17 @@ class InferenceRouter(Inference):
                 model=model,
             )
             for metric in completion_metrics:
-                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
+                if metric.metric in [
+                    "completion_tokens",
+                    "total_tokens",
+                ]:  # Only log completion and total tokens
                     enqueue_event(metric)

             # Return metrics in response
-            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
+            return [
+                MetricInResponse(metric=metric.metric, value=metric.value)
+                for metric in completion_metrics
+            ]

         # Fallback if no telemetry
         metrics = self._construct_metrics(
@@ -623,7 +689,10 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]

     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
@@ -664,33 +733,48 @@ class InferenceRouter(Inference):
                 if choice_delta.delta:
                     delta = choice_delta.delta
                     if delta.content:
-                        current_choice_data["content_parts"].append(delta.content)
+                        current_choice_data["content_parts"].append(
+                            delta.content
+                        )
                     if delta.tool_calls:
                         for tool_call_delta in delta.tool_calls:
                             tc_idx = tool_call_delta.index
-                            if tc_idx not in current_choice_data["tool_calls_builder"]:
-                                current_choice_data["tool_calls_builder"][tc_idx] = {
+                            if (
+                                tc_idx
+                                not in current_choice_data["tool_calls_builder"]
+                            ):
+                                current_choice_data["tool_calls_builder"][
+                                    tc_idx
+                                ] = {
                                     "id": None,
                                     "type": "function",
                                     "function_name_parts": [],
                                     "function_arguments_parts": [],
                                 }
-                            builder = current_choice_data["tool_calls_builder"][tc_idx]
+                            builder = current_choice_data["tool_calls_builder"][
+                                tc_idx
+                            ]
                             if tool_call_delta.id:
                                 builder["id"] = tool_call_delta.id
                             if tool_call_delta.type:
                                 builder["type"] = tool_call_delta.type
                             if tool_call_delta.function:
                                 if tool_call_delta.function.name:
-                                    builder["function_name_parts"].append(tool_call_delta.function.name)
+                                    builder["function_name_parts"].append(
+                                        tool_call_delta.function.name
+                                    )
                                 if tool_call_delta.function.arguments:
                                     builder["function_arguments_parts"].append(
                                         tool_call_delta.function.arguments
                                     )
                 if choice_delta.finish_reason:
-                    current_choice_data["finish_reason"] = choice_delta.finish_reason
+                    current_choice_data["finish_reason"] = (
+                        choice_delta.finish_reason
+                    )
                 if choice_delta.logprobs and choice_delta.logprobs.content:
-                    current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+                    current_choice_data["logprobs_content_parts"].extend(
+                        choice_delta.logprobs.content
+                    )

             # Compute metrics on final chunk
             if chunk.choices and chunk.choices[0].finish_reason:
@@ -720,8 +804,12 @@ class InferenceRouter(Inference):
             if choice_data["tool_calls_builder"]:
                 for tc_build_data in choice_data["tool_calls_builder"].values():
                     if tc_build_data["id"]:
-                        func_name = "".join(tc_build_data["function_name_parts"])
-                        func_args = "".join(tc_build_data["function_arguments_parts"])
+                        func_name = "".join(
+                            tc_build_data["function_name_parts"]
+                        )
+                        func_args = "".join(
+                            tc_build_data["function_arguments_parts"]
+                        )
                         assembled_tool_calls.append(
                             OpenAIChatCompletionToolCall(
                                 id=tc_build_data["id"],
@@ -734,10 +822,16 @@ class InferenceRouter(Inference):
             message = OpenAIAssistantMessageParam(
                 role="assistant",
                 content=content_str if content_str else None,
-                tool_calls=assembled_tool_calls if assembled_tool_calls else None,
+                tool_calls=(
+                    assembled_tool_calls if assembled_tool_calls else None
+                ),
             )
             logprobs_content = choice_data["logprobs_content_parts"]
-            final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
+            final_logprobs = (
+                OpenAIChoiceLogprobs(content=logprobs_content)
+                if logprobs_content
+                else None
+            )

             assembled_choices.append(
                 OpenAIChoice(
@@ -756,4 +850,6 @@ class InferenceRouter(Inference):
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
-        asyncio.create_task(self.store.store_chat_completion(final_response, messages))
+        asyncio.create_task(
+            self.store.store_chat_completion(final_response, messages)
+        )
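
The streaming hunks above keep the router's strategy of buffering tool-call deltas by index and joining the name/argument fragments once the stream ends. Below is a minimal sketch of that accumulation pattern, using plain dicts as stand-ins for the OpenAI chunk and tool-call types (the field names here are illustrative).

# Minimal sketch of the delta-accumulation pattern: tool-call fragments arrive
# keyed by index, and full calls are assembled after the last chunk.
def assemble_tool_calls(deltas: list[dict]) -> list[dict]:
    builders: dict[int, dict] = {}
    for delta in deltas:
        idx = delta["index"]
        builder = builders.setdefault(
            idx,
            {"id": None, "type": "function", "name_parts": [], "arg_parts": []},
        )
        if delta.get("id"):
            builder["id"] = delta["id"]
        if delta.get("name"):
            builder["name_parts"].append(delta["name"])
        if delta.get("arguments"):
            builder["arg_parts"].append(delta["arguments"])

    return [
        {
            "id": b["id"],
            "type": b["type"],
            "function": {
                "name": "".join(b["name_parts"]),
                "arguments": "".join(b["arg_parts"]),
            },
        }
        for b in builders.values()
        if b["id"]  # mirror the router: only emit calls that received an id
    ]


if __name__ == "__main__":
    chunks = [
        {"index": 0, "id": "call_1", "name": "get_weather"},
        {"index": 0, "arguments": '{"city": '},
        {"index": 0, "arguments": '"Paris"}'},
    ]
    print(assemble_tool_calls(chunks))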

View file

@@ -25,9 +25,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
-)

 from .config import SentenceTransformersInferenceConfig
@@ -35,7 +32,6 @@ log = get_logger(name=__name__, category="inference")

 class SentenceTransformersInferenceImpl(
-    OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
@@ -114,4 +110,6 @@ class SentenceTransformersInferenceImpl(
         # for fill-in-the-middle type completion
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
+        raise NotImplementedError(
+            "OpenAI completion not supported by sentence transformers provider"
+        )

View file

@@ -11,8 +11,7 @@ from cerebras.cloud.sdk import AsyncCerebras
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    CompletionRequest,
-    CompletionResponse,
+    ChatCompletionResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -25,9 +24,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
     TopKSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
-    completion_request_to_prompt,
 )

 from .config import CerebrasImplConfig
@@ -102,14 +98,18 @@ class CerebrasInferenceAdapter(
         else:
             return await self._nonstream_chat_completion(request)

-    async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)

         r = await self._cerebras_client.completions.create(**params)

         return process_chat_completion_response(r, request)

-    async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)

         stream = await self._cerebras_client.completions.create(**params)
@@ -117,15 +117,17 @@
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
+        if request.sampling_params and isinstance(
+            request.sampling_params.strategy, TopKSamplingStrategy
+        ):
             raise ValueError("`top_k` not supported by Cerebras")

         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-        elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request)
+            prompt = await chat_completion_request_to_prompt(
+                request, self.get_llama_model(request.model)
+            )
         else:
             raise ValueError(f"Unknown request type {type(request)}")

View file

@@ -10,11 +10,13 @@ from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse

-# from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
+from llama_stack.providers.utils.inference.model_registry import (
+    build_hf_repo_model_entry,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     get_sampling_options,
+    OpenAIChatCompletionToLlamaStackMixin,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
@@ -41,13 +43,12 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }

-SAFETY_MODELS_ENTRIES = []
 # Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(provider_model_id, model_descriptor)
     for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
-] + SAFETY_MODELS_ENTRIES
+]


 class RunpodInferenceAdapter(
@@ -56,7 +57,9 @@ class RunpodInferenceAdapter(
     OpenAIChatCompletionToLlamaStackMixin,
 ):
     def __init__(self, config: RunpodImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
+        )
         self.config = config

     async def initialize(self) -> None:
@@ -103,7 +106,9 @@ class RunpodInferenceAdapter(
         r = client.completions.create(**params)
         return process_chat_completion_response(r, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
         params = self._get_params(request)

         async def _to_async_generator():

View file

@@ -9,12 +9,10 @@ from typing import Any
 from ibm_watsonx_ai.foundation_models import Model
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
-    CompletionRequest,
     GreedySamplingStrategy,
     Inference,
     LogProbConfig,
@@ -48,6 +46,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     request_has_media,
 )
+from openai import AsyncOpenAI

 from . import WatsonXConfig
 from .models import MODEL_ENTRIES
@@ -85,7 +84,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         pass

     def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
+        config_api_key = (
+            self._config.api_key.get_secret_value() if self._config.api_key else None
+        )
         config_url = self._config.url
         project_id = self._config.project_id
         credentials = {"url": config_url, "apikey": config_api_key}
@@ -132,14 +133,18 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             return await self._nonstream_chat_completion(request)

-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client(request.model).generate(**params)
         choices = []
         if "results" in r:
             for result in r["results"]:
                 choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"] if result["stop_reason"] else None,
+                    finish_reason=(
+                        result["stop_reason"] if result["stop_reason"] else None
+                    ),
                     text=result["generated_text"],
                 )
                 choices.append(choice)
@@ -148,7 +153,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         )
         return process_chat_completion_response(response, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
         model_id = request.model
@@ -168,28 +175,44 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+            input_dict["prompt"] = await chat_completion_request_to_prompt(
+                request, llama_model
+            )
         else:
-            assert not media_present, "Together does not support media for Completion requests"
+            assert (
+                not media_present
+            ), "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

         if request.sampling_params:
             if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
+                input_dict["params"][
+                    GenParams.DECODING_METHOD
+                ] = request.sampling_params.strategy.type
             if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
+                input_dict["params"][
+                    GenParams.MAX_NEW_TOKENS
+                ] = request.sampling_params.max_tokens
             if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
+                input_dict["params"][
+                    GenParams.REPETITION_PENALTY
+                ] = request.sampling_params.repetition_penalty
             if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
+                input_dict["params"][
+                    GenParams.TOP_P
+                ] = request.sampling_params.strategy.top_p
+                input_dict["params"][
+                    GenParams.TEMPERATURE
+                ] = request.sampling_params.strategy.temperature
             if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
+                input_dict["params"][
+                    GenParams.TOP_K
+                ] = request.sampling_params.strategy.top_k
             if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
                 input_dict["params"][GenParams.TEMPERATURE] = 0.0