Merge branch 'main' into dead_code_removal

commit 9886520b40
Omar Abdelwahab authored 2025-10-06 13:21:36 -07:00, committed by GitHub
927 changed files with 171924 additions and 102933 deletions


@@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
    CompletionMessage,
    Inference,
    ListOpenAIChatCompletionResponse,
    LogProbConfig,
    Message,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
@@ -34,12 +33,7 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    Order,
    ResponseFormat,
    SamplingParams,
    StopReason,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
@@ -193,102 +187,6 @@ class InferenceRouter(Inference):
            raise ModelTypeError(model_id, model.model_type, expected_model_type)
        return model
    async def chat_completion(
        self,
        model_id: str,
        messages: list[Message],
        sampling_params: SamplingParams | None = None,
        response_format: ResponseFormat | None = None,
        tools: list[ToolDefinition] | None = None,
        tool_choice: ToolChoice | None = None,
        tool_prompt_format: ToolPromptFormat | None = None,
        stream: bool | None = False,
        logprobs: LogProbConfig | None = None,
        tool_config: ToolConfig | None = None,
    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
        logger.debug(
            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
        )
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self._get_model(model_id, ModelType.llm)
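        # Reconcile the standalone tool_choice / tool_prompt_format arguments with tool_config:
        # they must agree when both are given, otherwise a ToolConfig is built from them.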
        if tool_config:
            if tool_choice and tool_choice != tool_config.tool_choice:
                raise ValueError("tool_choice and tool_config.tool_choice must match")
            if (
                tool_prompt_format
                and tool_prompt_format != tool_config.tool_prompt_format
            ):
                raise ValueError(
                    "tool_prompt_format and tool_config.tool_prompt_format must match"
                )
        else:
            params = {}
            if tool_choice:
                params["tool_choice"] = tool_choice
            if tool_prompt_format:
                params["tool_prompt_format"] = tool_prompt_format
            tool_config = ToolConfig(**params)
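        # Honor tool_choice: drop all tools for "none", and for an explicit tool name
        # verify that it refers to one of the provided tool definitions.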
        tools = tools or []
        if tool_config.tool_choice == ToolChoice.none:
            tools = []
        elif tool_config.tool_choice == ToolChoice.auto:
            pass
        elif tool_config.tool_choice == ToolChoice.required:
            pass
        else:
            # verify tool_choice is one of the tools
            tool_names = [
                t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
                for t in tools
            ]
            if tool_config.tool_choice not in tool_names:
                raise ValueError(
                    f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
                )
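        # Forward the normalized arguments to the provider that the routing table
        # resolves for this model_id, counting prompt tokens up front for metrics.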
        params = dict(
            model_id=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools,
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        provider = await self.routing_table.get_provider_impl(model_id)
        prompt_tokens = await self._count_tokens(
            messages, tool_config.tool_prompt_format
        )
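        # Streaming responses have token counts and metrics computed chunk by chunk.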
        if stream:
            response_stream = await provider.chat_completion(**params)
            return self.stream_tokens_and_compute_metrics(
                response=response_stream,
                prompt_tokens=prompt_tokens,
                model=model,
                tool_prompt_format=tool_config.tool_prompt_format,
            )
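        # Non-streaming: compute usage metrics once over the complete response.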
        response = await provider.chat_completion(**params)
        metrics = await self.count_tokens_and_compute_metrics(
            response=response,
            prompt_tokens=prompt_tokens,
            model=model,
            tool_prompt_format=tool_config.tool_prompt_format,
        )
        # these metrics will show up in the client response.
        response.metrics = (
            metrics
            if not hasattr(response, "metrics") or response.metrics is None
            else response.metrics + metrics
        )
        return response
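        # NOTE (editor's assumption, not part of the diff): with this legacy method removed,
        # chat completions are presumably served through the OpenAI-compatible entry points
        # that remain on the router (e.g. openai_chat_completion), alongside openai_completion below.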
    async def openai_completion(
        self,
        model: str,