change to adapter

This commit is contained in:
Honglin Cao 2025-03-11 15:10:31 -04:00
parent a454b53bda
commit e2290a0096

View file

@ -1,3 +1,4 @@
# centml.py (updated)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@ -7,6 +8,7 @@
from typing import AsyncGenerator, List, Optional, Union
from openai import OpenAI
from pydantic import parse_obj_as
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
@ -53,8 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CentMLImplConfig
# Example model aliases that map from CentMLs
# published model identifiers to llama-stack's `CoreModelId`.
# Example model aliases that map from CentMLs published model identifiers
MODEL_ALIASES = [
build_model_entry(
"meta-llama/Llama-3.2-3B-Instruct",
@ -129,8 +130,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
# Completions.create() got an unexpected keyword argument 'response_format'
#response_format=response_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -142,22 +142,77 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
async def _nonstream_completion(
self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
# Using the older "completions" route for non-chat
response = self._get_client().completions.create(**params)
return process_completion_response(response)
if request.response_format is not None:
# For structured output, use the chat completions endpoint.
response = self._get_client().chat.completions.create(**params)
try:
result = process_chat_completion_response(response, request)
except KeyError as e:
if str(e) == "'parameters'":
# CentML's structured output may not include a tool call.
# Use the raw message content as the structured JSON.
raw_content = response.choices[0].message.content
message_obj = parse_obj_as(
Message, {
"role": "assistant",
"content": raw_content,
"stop_reason": "end_of_message"
})
result = ChatCompletionResponse(
completion_message=message_obj,
logprobs=None,
)
else:
raise
# If the processed content is still None, use the raw API content.
if result.completion_message.content is None:
raw_content = response.choices[0].message.content
if isinstance(result.completion_message, dict):
result.completion_message["content"] = raw_content
else:
updated_msg = result.completion_message.copy(
update={"content": raw_content})
result = result.copy(
update={"completion_message": updated_msg})
else:
response = self._get_client().completions.create(**params)
result = process_completion_response(response)
# If structured output returns token lists, join them.
if request.response_format is not None:
if isinstance(result.completion_message, dict):
content = result.completion_message.get("content")
if isinstance(content, list):
result.completion_message["content"] = "".join(content)
else:
if isinstance(result.completion_message.content, list):
updated_msg = result.completion_message.copy(update={
"content":
"".join(result.completion_message.content)
})
result = result.copy(
update={"completion_message": updated_msg})
return result
async def _stream_completion(self,
request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
stream = self._get_client().completions.create(**params)
if request.response_format is not None:
stream = self._get_client().chat.completions.create(**params)
else:
stream = self._get_client().completions.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream):
yield chunk
if request.response_format is not None:
async for chunk in process_chat_completion_stream_response(
stream, request):
yield chunk
else:
async for chunk in process_completion_stream_response(stream):
yield chunk
#
# CHAT COMPLETION
@ -187,8 +242,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
# Completions.create() got an unexpected keyword argument 'response_format'
#response_format=response_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
@ -200,15 +254,25 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
# For chat requests, if "messages" is in params -> .chat.completions
if "messages" in params:
response = self._get_client().chat.completions.create(**params)
else:
# fallback if we ended up only with "prompt"
response = self._get_client().completions.create(**params)
return process_chat_completion_response(response, request)
result = process_chat_completion_response(response, request)
if request.response_format is not None:
if isinstance(result.completion_message, dict):
content = result.completion_message.get("content")
if isinstance(content, list):
result.completion_message["content"] = "".join(content)
else:
if isinstance(result.completion_message.content, list):
updated_msg = result.completion_message.copy(update={
"content":
"".join(result.completion_message.content)
})
result = result.copy(
update={"completion_message": updated_msg})
return result
async def _stream_chat_completion(
self, request: ChatCompletionRequest) -> AsyncGenerator:
@ -237,18 +301,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present or not llama_model:
if request.response_format is not None:
if isinstance(request, ChatCompletionRequest):
input_dict["messages"] = [
await convert_message_to_openai_dict(m)
for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, llama_model)
prompt_str = await completion_request_to_prompt(request)
input_dict["messages"] = [{
"role": "user",
"content": prompt_str
}]
else:
input_dict["prompt"] = await completion_request_to_prompt(request)
if isinstance(request, ChatCompletionRequest):
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m)
for m in request.messages
]
else:
input_dict[
"prompt"] = await chat_completion_request_to_prompt(
request, llama_model)
else:
input_dict["prompt"] = await completion_request_to_prompt(
request)
params = {
"model":
request.model,
@ -263,26 +341,25 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
self,
sampling_params: Optional[SamplingParams],
logprobs: Optional[LogProbConfig],
fmt: ResponseFormat,
fmt: Optional[ResponseFormat],
) -> dict:
options = get_sampling_options(sampling_params)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
# CentML currently does not support guided decoding,
# the following setting is currently ignored by the server.
"schema": fmt.json_schema,
"type": "json_schema",
"json_schema": {
"name": "schema",
"schema": fmt.json_schema
},
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError(
"Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
return options
#
@ -301,7 +378,6 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
# CentML does not support media for embeddings.
assert all(not content_has_media(c) for c in contents), (
"CentML does not support media for embeddings")
resp = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(c) for c in contents],