mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: remove deprecated inference.chat_completion implementations
vllm - - requires max_tokens be set, use config value - set tool_choice to none if no tools provided
This commit is contained in:
parent
f1748e2f92
commit
f754e1b65b
18 changed files with 193 additions and 1411 deletions
|
@ -4,33 +4,22 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
|
@ -85,76 +74,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
provider_data=provider_data,
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
# TODO: revisit this remove tool_calls from messages logic
|
||||
for message in messages:
|
||||
if hasattr(message, "tool_calls"):
|
||||
message.tool_calls = None
|
||||
|
||||
request_params = {
|
||||
"model_id": model.provider_resource_id,
|
||||
"messages": messages,
|
||||
"sampling_params": sampling_params,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
"tool_prompt_format": tool_prompt_format,
|
||||
"response_format": response_format,
|
||||
"stream": stream,
|
||||
"logprobs": logprobs,
|
||||
}
|
||||
|
||||
# only pass through the not None params
|
||||
request_params = {key: value for key, value in request_params.items() if value is not None}
|
||||
|
||||
# cast everything to json dict
|
||||
json_params = self.cast_value_to_json_dict(request_params)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(json_params)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(json_params)
|
||||
|
||||
async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse:
|
||||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=response.completion_message.content.text,
|
||||
stop_reason=response.completion_message.stop_reason,
|
||||
tool_calls=response.completion_message.tool_calls,
|
||||
),
|
||||
logprobs=response.logprobs,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
stream_response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
async for chunk in stream_response:
|
||||
chunk = chunk.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
chunk["metrics"] = []
|
||||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue