fix endpoints

This commit is contained in:
Honglin Cao 2025-03-12 22:00:44 -04:00
parent 7c40048470
commit 3ab672dcda

View file

@@ -140,58 +140,40 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
return await self._nonstream_completion(request) return await self._nonstream_completion(request)
async def _nonstream_completion(
        self, request: CompletionRequest) -> CompletionResponse:
    """
    Handle a non-streaming completion request.

    When the request carries a structured-output spec (e.g. a JSON
    schema in ``response_format``), the chat completions endpoint is
    used and its chat-shaped reply is flattened into a plain
    ``CompletionResponse``. Otherwise the regular completions endpoint
    is called with a prompt string.
    """
    params = await self._get_params(request)

    if request.response_format is not None:
        # Structured output goes through the chat completions endpoint.
        response = self._get_client().chat.completions.create(**params)
        message = response.choices[0].message
        # The backend may return content as a list of pieces; join it
        # into a single string before building the response.
        raw = message.content
        content = "".join(raw) if isinstance(raw, list) else raw
        return CompletionResponse(
            content=content,
            # NOTE(review): stop_reason is hard-coded for now — the
            # original commit flags this as needing a real mapping from
            # the API's finish_reason. Fix later.
            stop_reason="end_of_message",
            logprobs=None,
        )

    # Plain completion: the completions endpoint expects a prompt, not
    # a messages list, so swap one for the other before the call.
    prompt_str = await completion_request_to_prompt(request)
    if "messages" in params:
        del params["messages"]
    params["prompt"] = prompt_str
    response = self._get_client().completions.create(**params)
    result = process_completion_response(response)
    # As above, normalize list-shaped content into one string.
    if isinstance(result.content, list):
        result.content = "".join(result.content)
    return result
async def _stream_completion(self, async def _stream_completion(self,
request: CompletionRequest) -> AsyncGenerator: request: CompletionRequest) -> AsyncGenerator: