fix endpoints

This commit is contained in:
Honglin Cao 2025-03-12 22:00:44 -04:00
parent 7c40048470
commit 3ab672dcda

View file

@@ -140,57 +140,39 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest) -> ChatCompletionResponse:
self, request: CompletionRequest) -> CompletionResponse:
"""
Process non-streaming completion requests.
If a structured output is specified (e.g. JSON schema),
the adapter calls the chat completions endpoint and then
converts the chat response into a plain CompletionResponse.
Otherwise, it uses the regular completions endpoint.
"""
params = await self._get_params(request)
if request.response_format is not None:
# For structured output, use the chat completions endpoint.
# Use the chat completions endpoint for structured output.
response = self._get_client().chat.completions.create(**params)
try:
result = process_chat_completion_response(response, request)
except KeyError as e:
if str(e) == "'parameters'":
# CentML's structured output may not include a tool call.
# Use the raw message content as the structured JSON.
raw_content = response.choices[0].message.content
message_obj = parse_obj_as(
Message, {
"role": "assistant",
"content": raw_content,
"stop_reason": "end_of_message"
})
result = ChatCompletionResponse(
completion_message=message_obj,
choice = response.choices[0]
message = choice.message
content = message.content if not isinstance(
message.content, list) else "".join(message.content)
return CompletionResponse(
content=content,
stop_reason=
"end_of_message", # hard code for now. need to fix later.
logprobs=None,
)
else:
raise
# If the processed content is still None, use the raw API content.
if result.completion_message.content is None:
raw_content = response.choices[0].message.content
if isinstance(result.completion_message, dict):
result.completion_message["content"] = raw_content
else:
updated_msg = result.completion_message.copy(
update={"content": raw_content})
result = result.copy(
update={"completion_message": updated_msg})
else:
# Use the completions endpoint with a prompt.
prompt_str = await completion_request_to_prompt(request)
if "messages" in params:
del params["messages"]
params["prompt"] = prompt_str
response = self._get_client().completions.create(**params)
result = process_completion_response(response)
# If structured output returns token lists, join them.
if request.response_format is not None:
if isinstance(result.completion_message, dict):
content = result.completion_message.get("content")
if isinstance(content, list):
result.completion_message["content"] = "".join(content)
else:
if isinstance(result.completion_message.content, list):
updated_msg = result.completion_message.copy(update={
"content":
"".join(result.completion_message.content)
})
result = result.copy(
update={"completion_message": updated_msg})
if isinstance(result.content, list):
result.content = "".join(result.content)
return result
async def _stream_completion(self,