mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-10 04:08:31 +00:00

fix endpoints

This commit is contained in:
parent 7c40048470
commit 3ab672dcda

1 changed file with 28 additions and 46 deletions
@@ -140,58 +140,40 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         return await self._nonstream_completion(request)
 
     async def _nonstream_completion(
-            self, request: CompletionRequest) -> ChatCompletionResponse:
+            self, request: CompletionRequest) -> CompletionResponse:
+        """
+        Process non-streaming completion requests.
+
+        If a structured output is specified (e.g. JSON schema),
+        the adapter calls the chat completions endpoint and then
+        converts the chat response into a plain CompletionResponse.
+        Otherwise, it uses the regular completions endpoint.
+        """
         params = await self._get_params(request)
         if request.response_format is not None:
-            # For structured output, use the chat completions endpoint.
+            # Use the chat completions endpoint for structured output.
             response = self._get_client().chat.completions.create(**params)
-            try:
-                result = process_chat_completion_response(response, request)
-            except KeyError as e:
-                if str(e) == "'parameters'":
-                    # CentML's structured output may not include a tool call.
-                    # Use the raw message content as the structured JSON.
-                    raw_content = response.choices[0].message.content
-                    message_obj = parse_obj_as(
-                        Message, {
-                            "role": "assistant",
-                            "content": raw_content,
-                            "stop_reason": "end_of_message"
-                        })
-                    result = ChatCompletionResponse(
-                        completion_message=message_obj,
-                        logprobs=None,
-                    )
-                else:
-                    raise
-            # If the processed content is still None, use the raw API content.
-            if result.completion_message.content is None:
-                raw_content = response.choices[0].message.content
-                if isinstance(result.completion_message, dict):
-                    result.completion_message["content"] = raw_content
-                else:
-                    updated_msg = result.completion_message.copy(
-                        update={"content": raw_content})
-                    result = result.copy(
-                        update={"completion_message": updated_msg})
+            choice = response.choices[0]
+            message = choice.message
+            content = message.content if not isinstance(
+                message.content, list) else "".join(message.content)
+            return CompletionResponse(
+                content=content,
+                stop_reason=
+                "end_of_message",  # hard code for now. need to fix later.
+                logprobs=None,
+            )
         else:
+            # Use the completions endpoint with a prompt.
+            prompt_str = await completion_request_to_prompt(request)
+            if "messages" in params:
+                del params["messages"]
+            params["prompt"] = prompt_str
             response = self._get_client().completions.create(**params)
             result = process_completion_response(response)
-        # If structured output returns token lists, join them.
-        if request.response_format is not None:
-            if isinstance(result.completion_message, dict):
-                content = result.completion_message.get("content")
-                if isinstance(content, list):
-                    result.completion_message["content"] = "".join(content)
-            else:
-                if isinstance(result.completion_message.content, list):
-                    updated_msg = result.completion_message.copy(update={
-                        "content":
-                        "".join(result.completion_message.content)
-                    })
-                    result = result.copy(
-                        update={"completion_message": updated_msg})
-        return result
+            if isinstance(result.content, list):
+                result.content = "".join(result.content)
+            return result
 
     async def _stream_completion(self,
                                  request: CompletionRequest) -> AsyncGenerator:
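The heart of the change is the conversion the new docstring describes: when a response_format is set, the adapter calls the chat completions endpoint and flattens the reply into a plain completion. Below is a minimal, self-contained sketch of that flattening step; the dataclasses are hypothetical stand-ins for the OpenAI client's response objects and llama-stack's CompletionResponse, not the real types.

from dataclasses import dataclass
from typing import List, Optional, Union


# Stand-in types for illustration only; the real adapter works with the
# OpenAI client's response objects and llama-stack's CompletionResponse.
@dataclass
class Message:
    content: Union[str, List[str]]


@dataclass
class Choice:
    message: Message


@dataclass
class ChatResponse:
    choices: List[Choice]


@dataclass
class CompletionResult:
    content: str
    stop_reason: str
    logprobs: Optional[list]


def chat_to_completion(response: ChatResponse) -> CompletionResult:
    """Flatten a chat-style reply into a plain completion, joining
    list content the way the new adapter code does."""
    message = response.choices[0].message
    content = message.content if not isinstance(
        message.content, list) else "".join(message.content)
    return CompletionResult(
        content=content,
        stop_reason="end_of_message",  # hard-coded, mirroring the diff
        logprobs=None,
    )


# A structured-output reply whose content arrives as token chunks:
reply = ChatResponse(choices=[Choice(Message(['{"name": ', '"llama"}']))])
print(chat_to_completion(reply).content)  # {"name": "llama"}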
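On the non-structured path, the fix rewrites the request params from chat shape to prompt shape before hitting the completions endpoint. A sketch of that rewrite with a hypothetical params dict; the real values would come from _get_params and completion_request_to_prompt.

# Hypothetical chat-shaped params, as _get_params might return them.
params = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello"}],
    "temperature": 0.7,
}

# Stand-in for what completion_request_to_prompt(request) would yield.
prompt_str = "Hello"

# Mirror the diff: drop the chat-only field, add the prompt field,
# so the dict can be splatted into client.completions.create(**params).
if "messages" in params:
    del params["messages"]
params["prompt"] = prompt_str

print(params)  # {'model': ..., 'temperature': 0.7, 'prompt': 'Hello'}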