From 102af46d5d2915e373cc6da3097753babbd56c0f Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 4 Feb 2025 16:22:54 -0500
Subject: [PATCH] fixing centml get params

---
 .../remote/inference/centml/centml.py | 36 ++++++++++++-------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py
index aacc73804..c3798837b 100644
--- a/llama_stack/providers/remote/inference/centml/centml.py
+++ b/llama_stack/providers/remote/inference/centml/centml.py
@@ -237,6 +237,7 @@ class CentMLInferenceAdapter(
     async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
+        """
         Build the 'params' dict that the OpenAI (CentML) client expects.
 
         For chat requests, we always prefer "messages" so that it calls
@@ -251,25 +252,36 @@ class CentMLInferenceAdapter(
                 await convert_message_to_openai_dict(m)
                 for m in request.messages
             ]
-
         else:  # Non-chat (CompletionRequest)
-            assert not media_present, (
-                "CentML does not support media for completions"
-            )
+            assert not media_present, "CentML does not support media for completions"
             input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter
-            )
+                request, self.formatter)
 
-        return {
-            "model": request.model,
+        params = {
+            "model":
+            request.model,
             **input_dict,
-            "stream": request.stream,
-            **self._build_options(
-                request.sampling_params, request.response_format
-            ),
+            "stream":
+            request.stream,
+            **self._build_options(request.sampling_params, request.response_format),
         }
 
+        # For non-chat completions (i.e. when using a "prompt"), CentML's
+        # completions endpoint does not support the response_format parameter.
+        if "prompt" in params and "response_format" in params:
+            del params["response_format"]
+
+        # For chat completions with structured output, CentML requires
+        # guided decoding settings to use num_scheduler_steps=1 and spec_enabled=False.
+        # Override these if a response_format was requested.
+        if "messages" in params and request.response_format:
+            params["num_scheduler_steps"] = 1
+            params["spec_enabled"] = False
+
+        return params
+
+
     def _build_options(
         self,
         sampling_params: Optional[SamplingParams],
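
Reviewer note: the two post-processing rules this patch adds to _get_params can be
checked in isolation. The sketch below mirrors that logic in a standalone helper;
finalize_centml_params is a hypothetical name used only for illustration (it is not
part of the patch or of llama_stack), and the literal dicts stand in for the params
that _build_options and the request conversion would produce. Here the dict itself
carries response_format, standing in for request.response_format in the real code.

from typing import Any, Dict

def finalize_centml_params(params: Dict[str, Any]) -> Dict[str, Any]:
    # Rule 1 (mirrors the patch): CentML's prompt-based completions endpoint
    # does not accept response_format, so strip it when sending a raw prompt.
    if "prompt" in params and "response_format" in params:
        del params["response_format"]
    # Rule 2 (mirrors the patch): chat requests with structured output need
    # guided decoding to run with num_scheduler_steps=1 and speculative
    # decoding disabled.
    if "messages" in params and params.get("response_format"):
        params["num_scheduler_steps"] = 1
        params["spec_enabled"] = False
    return params

# Completion request: response_format is dropped before the call.
assert "response_format" not in finalize_centml_params(
    {"model": "m", "prompt": "hi", "response_format": {"type": "json_object"}}
)

# Chat request with structured output: the scheduler overrides are applied.
chat = finalize_centml_params(
    {"model": "m",
     "messages": [{"role": "user", "content": "hi"}],
     "response_format": {"type": "json_object"}}
)
assert chat["num_scheduler_steps"] == 1 and chat["spec_enabled"] is False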