Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)
fix endpoints
Commit 3ab672dcda (parent 7c40048470)
1 changed file with 28 additions and 46 deletions
```diff
@@ -140,58 +140,40 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-            self, request: CompletionRequest) -> ChatCompletionResponse:
+            self, request: CompletionRequest) -> CompletionResponse:
         """
         Process non-streaming completion requests.
 
         If a structured output is specified (e.g. JSON schema),
         the adapter calls the chat completions endpoint and then
         converts the chat response into a plain CompletionResponse.
         Otherwise, it uses the regular completions endpoint.
         """
         params = await self._get_params(request)
         if request.response_format is not None:
-            # For structured output, use the chat completions endpoint.
+            # Use the chat completions endpoint for structured output.
             response = self._get_client().chat.completions.create(**params)
             try:
                 result = process_chat_completion_response(response, request)
             except KeyError as e:
                 if str(e) == "'parameters'":
                     # CentML's structured output may not include a tool call.
                     # Use the raw message content as the structured JSON.
                     raw_content = response.choices[0].message.content
                     message_obj = parse_obj_as(
                         Message, {
                             "role": "assistant",
                             "content": raw_content,
                             "stop_reason": "end_of_message"
                         })
                     result = ChatCompletionResponse(
                         completion_message=message_obj,
                         logprobs=None,
                     )
                 else:
                     raise
```
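A note on the `str(e) == "'parameters'"` guard: `str()` of a `KeyError` renders the `repr()` of the missing key, so the quotes are part of the comparison string. A minimal standalone sketch (the dict below is a stand-in for the tool-call payload the shared response processor indexes into, not the adapter's actual data):

```python
# str(KeyError) includes the missing key's repr, quotes and all.
tool_call = {"name": "respond"}  # stand-in: no "parameters" key, as CentML may return
try:
    tool_call["parameters"]
except KeyError as e:
    assert str(e) == "'parameters'"  # the exact string the adapter matches on
```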
```diff
             # If the processed content is still None, use the raw API content.
             if result.completion_message.content is None:
                 raw_content = response.choices[0].message.content
                 if isinstance(result.completion_message, dict):
                     result.completion_message["content"] = raw_content
                 else:
                     updated_msg = result.completion_message.copy(
                         update={"content": raw_content})
                     result = result.copy(
                         update={"completion_message": updated_msg})
             choice = response.choices[0]
             message = choice.message
             content = message.content if not isinstance(
                 message.content, list) else "".join(message.content)
             return CompletionResponse(
                 content=content,
                 stop_reason="end_of_message",  # hard-coded for now; needs a real mapping later
                 logprobs=None,
             )
```
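The `copy(update=...)` calls above use pydantic's immutable-update idiom: build a patched copy rather than mutating a (possibly frozen) model in place. A minimal sketch of the pattern with a stand-in model, not the actual llama-stack `Message` type (pydantic v2 spells this `model_copy(update=...)`; the v1-style `copy` used here still works there with a deprecation warning):

```python
from typing import Optional

from pydantic import BaseModel

class DemoMessage(BaseModel):
    role: str
    content: Optional[str] = None

msg = DemoMessage(role="assistant", content=None)
patched = msg.copy(update={"content": '{"title": "hello"}'})  # new object
assert msg.content is None                 # original left untouched
assert patched.content == '{"title": "hello"}'
```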
```diff
         else:
             # Use the completions endpoint with a prompt.
             prompt_str = await completion_request_to_prompt(request)
             if "messages" in params:
                 del params["messages"]
             params["prompt"] = prompt_str
             response = self._get_client().completions.create(**params)
             result = process_completion_response(response)
             # If structured output returns token lists, join them.
             if request.response_format is not None:
                 if isinstance(result.completion_message, dict):
                     content = result.completion_message.get("content")
                     if isinstance(content, list):
                         result.completion_message["content"] = "".join(content)
                 else:
                     if isinstance(result.completion_message.content, list):
                         updated_msg = result.completion_message.copy(
                             update={"content": "".join(result.completion_message.content)})
                         result = result.copy(
                             update={"completion_message": updated_msg})
                 return result
             if isinstance(result.content, list):
                 result.content = "".join(result.content)
             return result
 
     async def _stream_completion(self,
                                  request: CompletionRequest) -> AsyncGenerator:
```
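For reference, a hypothetical sketch of how the two code paths might be exercised. The `CompletionRequest` and `JsonSchemaResponseFormat` types are from llama-stack's inference API; the adapter instance, model id, and schema are illustrative assumptions, not part of this commit:

```python
from llama_stack.apis.inference import CompletionRequest, JsonSchemaResponseFormat

async def demo(adapter: "CentMLInferenceAdapter") -> None:
    # No response_format: the adapter uses the plain completions endpoint.
    plain = await adapter._nonstream_completion(
        CompletionRequest(model="meta-llama/Llama-3.1-8B-Instruct",
                          content="Write a haiku about GPUs."))
    print(plain.content)

    # With response_format: the adapter calls chat completions and converts
    # the chat response back into a CompletionResponse, as in the diff above.
    structured = await adapter._nonstream_completion(
        CompletionRequest(
            model="meta-llama/Llama-3.1-8B-Instruct",
            content="Return a JSON object with a single 'title' field.",
            response_format=JsonSchemaResponseFormat(json_schema={
                "type": "object",
                "properties": {"title": {"type": "string"}},
            })))
    print(structured.content)
```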