fix endpoint

Honglin Cao 2025-03-04 18:26:08 -05:00
parent 98549b826d
commit e31a52b26e


@@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from openai import OpenAI
+from llama_stack import logcat
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -75,7 +76,6 @@ class CentMLInferenceAdapter(
     def __init__(self, config: CentMLImplConfig) -> None:
         super().__init__(MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         pass
@@ -127,7 +127,8 @@ class CentMLInferenceAdapter(
             model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
-            response_format=response_format,
+            # Completions.create() got an unexpected keyword argument 'response_format'
+            #response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
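
Side note (not part of the commit): commenting the argument out works, but another option is to strip any keyword the CentML completions endpoint rejects just before the call. A minimal sketch, assuming the standard OpenAI Python client; UNSUPPORTED_COMPLETION_KWARGS and _prune_completion_params are hypothetical names, not from this repository:

# Hypothetical helper, not in this codebase: drop arguments that
# Completions.create() is known to reject (e.g. 'response_format').
UNSUPPORTED_COMPLETION_KWARGS = {"response_format"}

def _prune_completion_params(params: dict) -> dict:
    return {k: v for k, v in params.items() if k not in UNSUPPORTED_COMPLETION_KWARGS}

# usage sketch:
#   response = self._get_client().completions.create(**_prune_completion_params(params))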
@@ -142,7 +143,7 @@ class CentMLInferenceAdapter(
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)

     async def _stream_completion(
         self, request: CompletionRequest
@@ -156,7 +157,7 @@ class CentMLInferenceAdapter(
         stream = _to_async_generator()

         async for chunk in process_completion_stream_response(
-            stream, self.formatter
+            stream
         ):
             yield chunk
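
As a usage-level illustration (not from the commit), the streaming path is an async generator, so a caller drives it with `async for`; a minimal, self-contained sketch with a stub generator standing in for the adapter:

import asyncio
from typing import AsyncGenerator

async def fake_stream() -> AsyncGenerator[str, None]:
    # Stub standing in for CentMLInferenceAdapter._stream_completion()
    for piece in ("Hello", ", ", "world"):
        yield piece

async def main() -> None:
    async for chunk in fake_stream():
        print(chunk, end="")

asyncio.run(main())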
@@ -188,7 +189,8 @@ class CentMLInferenceAdapter(
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
-            response_format=response_format,
+            # Completions.create() got an unexpected keyword argument 'response_format'
+            #response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -209,7 +211,7 @@ class CentMLInferenceAdapter(
             # fallback if we ended up only with "prompt"
             response = self._get_client().completions.create(**params)
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -226,62 +228,34 @@ class CentMLInferenceAdapter(
         stream = _to_async_generator()

         async for chunk in process_chat_completion_stream_response(
-            stream, self.formatter
-        ):
+                stream, request):
             yield chunk

     #
     # HELPER METHODS
     #

-    async def _get_params(
-        self, request: Union[ChatCompletionRequest, CompletionRequest]
-    ) -> dict:
-        """
-        Build the 'params' dict that the OpenAI (CentML) client expects.
-        For chat requests, we always prefer "messages" so that it calls
-        the chat endpoint properly.
-        """
+    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)

         if isinstance(request, ChatCompletionRequest):
-            # For chat requests, always build "messages" from the user messages
-            input_dict["messages"] = [
-                await convert_message_to_openai_dict(m)
-                for m in request.messages
-            ]
+            if media_present or not llama_model:
+                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
+            else:
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
-            # Non-chat (CompletionRequest)
-            assert not media_present, "CentML does not support media for completions"
-            input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         params = {
-            "model":
-            request.model,
+            "model": request.model,
             **input_dict,
-            "stream":
-            request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            "stream": request.stream,
+            **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }

-        # For non-chat completions (i.e. when using a "prompt"), CentML's
-        # completions endpoint does not support the response_format parameter.
-        if "prompt" in params and "response_format" in params:
-            del params["response_format"]
-
-        # For chat completions with structured output, CentML requires
-        # guided decoding settings to use num_scheduler_steps=1 and spec_enabled=False.
-        # Override these if a response_format was requested.
-        if "messages" in params and request.response_format:
-            params["num_scheduler_steps"] = 1
-            params["spec_enabled"] = False
-
+        logcat.debug("inference", f"params to centml: {params}")
         return params

     def _build_options(
         self,
         sampling_params: Optional[SamplingParams],
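
For orientation only (not part of the diff): with the rewritten _get_params, a chat request keeps "messages" when media is present or the model is not a known Llama model, is rendered to a raw "prompt" otherwise, and a plain completion always uses "prompt". A rough sketch of the resulting dicts, with made-up field values:

# Illustrative shapes only; the model name and message content are invented.
chat_params_messages_path = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": False,
    # ...plus sampling/logprobs/response_format options from _build_options()
}

completion_params_prompt_path = {
    "model": "meta-llama/Llama-3.3-70B-Instruct",
    "prompt": "<|begin_of_text|>...",  # produced by completion_request_to_prompt()
    "stream": False,
}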
@@ -308,6 +282,30 @@ class CentMLInferenceAdapter(
         return options

+    def _build_options(
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.json_schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError(
+                    "Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = 1
+
+        return options
+
     #
     # EMBEDDINGS
     #
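
To make the new _build_options easier to follow, here is a standalone sketch of the same mapping with plain types standing in for the llama_stack ResponseFormat and LogProbConfig classes (sampling options omitted); it is illustrative only, not the shipped implementation:

from typing import Any, Optional

def build_centml_options(fmt_type: Optional[str],
                         json_schema: Optional[dict],
                         top_k: Optional[int]) -> dict:
    """Mirror of the diff's logic: map structured-output and logprob settings
    onto the OpenAI-style keyword arguments the chat endpoint accepts."""
    options: dict = {}
    if fmt_type == "json_schema":
        options["response_format"] = {"type": "json_object", "schema": json_schema}
    elif fmt_type == "grammar":
        raise NotImplementedError("Grammar response format not supported yet")
    elif fmt_type is not None:
        raise ValueError(f"Unknown response format {fmt_type}")
    if top_k:
        # mirrors the diff: logprobs is set to 1 whenever top_k is requested
        options["logprobs"] = 1
    return options

# usage sketch:
#   build_centml_options("json_schema", {"type": "object"}, top_k=1)
#   -> {'response_format': {'type': 'json_object', 'schema': {'type': 'object'}}, 'logprobs': 1}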