From 3ab672dcda1196672202534abdef30a384e112c6 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Wed, 12 Mar 2025 22:00:44 -0400 Subject: [PATCH] fix endpoints --- .../remote/inference/centml/centml.py | 74 +++++++------------ 1 file changed, 28 insertions(+), 46 deletions(-) diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py index 0ca9cd54d..bf38b0387 100644 --- a/llama_stack/providers/remote/inference/centml/centml.py +++ b/llama_stack/providers/remote/inference/centml/centml.py @@ -140,58 +140,40 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference, return await self._nonstream_completion(request) async def _nonstream_completion( - self, request: CompletionRequest) -> ChatCompletionResponse: + self, request: CompletionRequest) -> CompletionResponse: + """ + Process non-streaming completion requests. + + If a structured output is specified (e.g. JSON schema), + the adapter calls the chat completions endpoint and then + converts the chat response into a plain CompletionResponse. + Otherwise, it uses the regular completions endpoint. + """ params = await self._get_params(request) if request.response_format is not None: - # For structured output, use the chat completions endpoint. + # Use the chat completions endpoint for structured output. response = self._get_client().chat.completions.create(**params) - try: - result = process_chat_completion_response(response, request) - except KeyError as e: - if str(e) == "'parameters'": - # CentML's structured output may not include a tool call. - # Use the raw message content as the structured JSON. - raw_content = response.choices[0].message.content - message_obj = parse_obj_as( - Message, { - "role": "assistant", - "content": raw_content, - "stop_reason": "end_of_message" - }) - result = ChatCompletionResponse( - completion_message=message_obj, - logprobs=None, - ) - else: - raise - # If the processed content is still None, use the raw API content. - if result.completion_message.content is None: - raw_content = response.choices[0].message.content - if isinstance(result.completion_message, dict): - result.completion_message["content"] = raw_content - else: - updated_msg = result.completion_message.copy( - update={"content": raw_content}) - result = result.copy( - update={"completion_message": updated_msg}) + choice = response.choices[0] + message = choice.message + content = message.content if not isinstance( + message.content, list) else "".join(message.content) + return CompletionResponse( + content=content, + stop_reason= + "end_of_message", # hard code for now. need to fix later. + logprobs=None, + ) else: + # Use the completions endpoint with a prompt. + prompt_str = await completion_request_to_prompt(request) + if "messages" in params: + del params["messages"] + params["prompt"] = prompt_str response = self._get_client().completions.create(**params) result = process_completion_response(response) - # If structured output returns token lists, join them. - if request.response_format is not None: - if isinstance(result.completion_message, dict): - content = result.completion_message.get("content") - if isinstance(content, list): - result.completion_message["content"] = "".join(content) - else: - if isinstance(result.completion_message.content, list): - updated_msg = result.completion_message.copy(update={ - "content": - "".join(result.completion_message.content) - }) - result = result.copy( - update={"completion_message": updated_msg}) - return result + if isinstance(result.content, list): + result.content = "".join(result.content) + return result async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: