fixing centml get params

This commit is contained in:
Honglin Cao 2025-02-04 16:22:54 -05:00
parent 6c1b1722b4
commit 102af46d5d

View file

@ -237,6 +237,7 @@ class CentMLInferenceAdapter(
async def _get_params( async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest] self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict: ) -> dict:
""" """
Build the 'params' dict that the OpenAI (CentML) client expects. Build the 'params' dict that the OpenAI (CentML) client expects.
For chat requests, we always prefer "messages" so that it calls For chat requests, we always prefer "messages" so that it calls
@ -251,25 +252,36 @@ class CentMLInferenceAdapter(
await convert_message_to_openai_dict(m) await convert_message_to_openai_dict(m)
for m in request.messages for m in request.messages
] ]
else: else:
# Non-chat (CompletionRequest) # Non-chat (CompletionRequest)
assert not media_present, ( assert not media_present, "CentML does not support media for completions"
"CentML does not support media for completions"
)
input_dict["prompt"] = await completion_request_to_prompt( input_dict["prompt"] = await completion_request_to_prompt(
request, self.formatter request, self.formatter)
)
return { params = {
"model": request.model, "model":
request.model,
**input_dict, **input_dict,
"stream": request.stream, "stream":
**self._build_options( request.stream,
request.sampling_params, request.response_format **self._build_options(request.sampling_params, request.response_format),
),
} }
# For non-chat completions (i.e. when using a "prompt"), CentML's
# completions endpoint does not support the response_format parameter.
if "prompt" in params and "response_format" in params:
del params["response_format"]
# For chat completions with structured output, CentML requires
# guided decoding settings to use num_scheduler_steps=1 and spec_enabled=False.
# Override these if a response_format was requested.
if "messages" in params and request.response_format:
params["num_scheduler_steps"] = 1
params["spec_enabled"] = False
return params
def _build_options( def _build_options(
self, self,
sampling_params: Optional[SamplingParams], sampling_params: Optional[SamplingParams],