From e31a52b26ee19471338685b1b6a909e74126c9a4 Mon Sep 17 00:00:00 2001
From: Honglin Cao
Date: Tue, 4 Mar 2025 18:26:08 -0500
Subject: [PATCH] fix endpoint

---
 .../remote/inference/centml/centml.py         | 90 +++++++++----------
 1 file changed, 44 insertions(+), 46 deletions(-)

diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py
index 5f14c23f1..02b4df475 100644
--- a/llama_stack/providers/remote/inference/centml/centml.py
+++ b/llama_stack/providers/remote/inference/centml/centml.py
@@ -8,6 +8,7 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from openai import OpenAI
+from llama_stack import logcat
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -75,7 +76,6 @@ class CentMLInferenceAdapter(
     def __init__(self, config: CentMLImplConfig) -> None:
         super().__init__(MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         pass
@@ -127,7 +127,8 @@ class CentMLInferenceAdapter(
             model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
-            response_format=response_format,
+            # Completions.create() got an unexpected keyword argument 'response_format'
+            #response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -142,7 +143,7 @@ class CentMLInferenceAdapter(
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)
 
     async def _stream_completion(
         self, request: CompletionRequest
@@ -156,7 +157,7 @@ class CentMLInferenceAdapter(
 
         stream = _to_async_generator()
         async for chunk in process_completion_stream_response(
-            stream, self.formatter
+            stream
         ):
             yield chunk
 
@@ -188,7 +189,8 @@ class CentMLInferenceAdapter(
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
-            response_format=response_format,
+            # Completions.create() got an unexpected keyword argument 'response_format'
+            #response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -209,7 +211,7 @@ class CentMLInferenceAdapter(
 
         # fallback if we ended up only with "prompt"
         response = self._get_client().completions.create(**params)
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -226,62 +228,34 @@ class CentMLInferenceAdapter(
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-            stream, self.formatter
-        ):
+            stream, request):
             yield chunk
 
     #
     # HELPER METHODS
     #
 
-    async def _get_params(
-        self, request: Union[ChatCompletionRequest, CompletionRequest]
-    ) -> dict:
-
-        """
-        Build the 'params' dict that the OpenAI (CentML) client expects.
-        For chat requests, we always prefer "messages" so that it calls
-        the chat endpoint properly.
- """ + async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: input_dict = {} media_present = request_has_media(request) - + llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - # For chat requests, always build "messages" from the user messages - input_dict["messages"] = [ - await convert_message_to_openai_dict(m) - for m in request.messages - ] + if media_present or not llama_model: + input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages] + else: + input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) else: - # Non-chat (CompletionRequest) - assert not media_present, "CentML does not support media for completions" - input_dict["prompt"] = await completion_request_to_prompt( - request, self.formatter) + input_dict["prompt"] = await completion_request_to_prompt(request) params = { - "model": - request.model, + "model": request.model, **input_dict, - "stream": - request.stream, - **self._build_options(request.sampling_params, request.response_format), + "stream": request.stream, + **self._build_options(request.sampling_params, request.logprobs, request.response_format), } - - # For non-chat completions (i.e. when using a "prompt"), CentML's - # completions endpoint does not support the response_format parameter. - if "prompt" in params and "response_format" in params: - del params["response_format"] - - # For chat completions with structured output, CentML requires - # guided decoding settings to use num_scheduler_steps=1 and spec_enabled=False. - # Override these if a response_format was requested. - if "messages" in params and request.response_format: - params["num_scheduler_steps"] = 1 - params["spec_enabled"] = False - + logcat.debug("inference", f"params to centml: {params}") return params - def _build_options( self, sampling_params: Optional[SamplingParams], @@ -308,6 +282,30 @@ class CentMLInferenceAdapter( return options + def _build_options( + self, + sampling_params: Optional[SamplingParams], + logprobs: Optional[LogProbConfig], + fmt: ResponseFormat, + ) -> dict: + options = get_sampling_options(sampling_params) + if fmt: + if fmt.type == ResponseFormatType.json_schema.value: + options["response_format"] = { + "type": "json_object", + "schema": fmt.json_schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + raise NotImplementedError( + "Grammar response format not supported yet") + else: + raise ValueError(f"Unknown response format {fmt.type}") + + if logprobs and logprobs.top_k: + options["logprobs"] = 1 + + return options + # # EMBEDDINGS #