diff --git a/llama_stack/providers/remote/inference/centml/centml.py b/llama_stack/providers/remote/inference/centml/centml.py
index bf38b0387..3d51efc6f 100644
--- a/llama_stack/providers/remote/inference/centml/centml.py
+++ b/llama_stack/providers/remote/inference/centml/centml.py
@@ -1,4 +1,3 @@
-# centml.py (updated)
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -55,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import CentMLImplConfig
 
-# Example model aliases that map from CentML’s published model identifiers
+# Update this if the list of models changes.
 MODEL_ALIASES = [
     build_model_entry(
         "meta-llama/Llama-3.2-3B-Instruct",
@@ -151,26 +150,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         """
         params = await self._get_params(request)
         if request.response_format is not None:
-            # Use the chat completions endpoint for structured output.
+            # ***** HACK: Use the chat completions endpoint even for non-chat completions.
+            # This is necessary because CentML's structured output (JSON schema) support
+            # is only available via the chat API. However, our API expects a CompletionResponse.
             response = self._get_client().chat.completions.create(**params)
             choice = response.choices[0]
             message = choice.message
+            # If message.content is returned as a list of tokens, join them into a string.
             content = message.content if not isinstance(
                 message.content, list) else "".join(message.content)
             return CompletionResponse(
                 content=content,
                 stop_reason=
-                "end_of_message",  # hard code for now. need to fix later.
+                "end_of_message",  # ***** HACK: Hard-coded stop_reason because the chat API doesn't return one.
                 logprobs=None,
             )
         else:
-            # Use the completions endpoint with a prompt.
+            # ***** HACK: For non-structured outputs, ensure we use the completions endpoint.
+            # _get_params may include a "messages" key due to our unified parameter builder.
+            # We remove "messages" and instead set a "prompt", since the completions endpoint expects it.
             prompt_str = await completion_request_to_prompt(request)
             if "messages" in params:
                 del params["messages"]
             params["prompt"] = prompt_str
             response = self._get_client().completions.create(**params)
             result = process_completion_response(response)
+            # Join tokenized content if needed.
             if isinstance(result.content, list):
                 result.content = "".join(result.content)
             return result
@@ -180,6 +185,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         params = await self._get_params(request)
 
         async def _to_async_generator():
+            # ***** HACK: For streaming structured outputs, use the chat completions endpoint.
+            # Otherwise, use the regular completions endpoint.
             if request.response_format is not None:
                 stream = self._get_client().chat.completions.create(**params)
             else:
@@ -236,11 +243,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _nonstream_chat_completion(
             self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        # Use the chat completions endpoint if the "messages" key is present.
         if "messages" in params:
            response = self._get_client().chat.completions.create(**params)
         else:
             response = self._get_client().completions.create(**params)
         result = process_chat_completion_response(response, request)
+        # ***** HACK: Sometimes the returned content is tokenized as a list.
+        # We join the tokens into a single string to produce a unified output.
         if request.response_format is not None:
             if isinstance(result.completion_message, dict):
                 content = result.completion_message.get("content")
@@ -261,6 +271,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         params = await self._get_params(request)
 
         async def _to_async_generator():
+            # ***** HACK: Use the chat completions endpoint if the "messages" key is present.
             if "messages" in params:
                 stream = self._get_client().chat.completions.create(**params)
             else:
@@ -280,6 +291,11 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _get_params(
             self, request: Union[ChatCompletionRequest,
                                  CompletionRequest]) -> dict:
+        """
+        Build a unified set of parameters for both chat and non-chat requests.
+        When structured output is specified (response_format is not None), we force
+        the use of a "messages" array even for CompletionRequests.
+        """
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
@@ -290,6 +306,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
                 for m in request.messages
             ]
         else:
+            # ***** HACK: For CompletionRequests with structured output,
+            # we simulate a chat conversation by wrapping the prompt as a single user message.
             prompt_str = await completion_request_to_prompt(request)
             input_dict["messages"] = [{
                 "role": "user",
@@ -325,6 +343,10 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         logprobs: Optional[LogProbConfig],
         fmt: Optional[ResponseFormat],
     ) -> dict:
+        """
+        Build additional options such as sampling parameters and logprobs.
+        Also translate our response_format into the format expected by CentML's API.
+        """
         options = get_sampling_options(sampling_params)
         if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
@@ -356,8 +378,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         output_dimension: Optional[int],
         contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
+        # this will come in future updates
         model = await self.model_store.get_model(model_id)
-        # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
             "CentML does not support media for embeddings")
         resp = self._get_client().embeddings.create(
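
Note (not part of the patch): the hunks above all implement one workaround, routing structured-output requests through the chat completions endpoint (wrapping the prompt as a single user message) and plain completions through the completions endpoint, while joining list-of-token content into a string in both cases. The sketch below restates that routing decision in isolation, assuming an OpenAI-compatible client like the one the adapter wraps; the helper names (`build_request_params`, `run_completion`) and the plain-dict `response_format` are illustrative assumptions, not part of the CentML adapter.

```python
# Minimal sketch of the routing logic described in the diff comments.
# Assumes an OpenAI-compatible client; helper names are illustrative only.
from typing import Any, Dict, Optional


def build_request_params(
    prompt: str,
    response_format: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Return kwargs for either the chat or the plain completions endpoint."""
    if response_format is not None:
        # Structured output: wrap the prompt as a single user message and
        # attach the schema, mirroring _get_params() in the diff.
        return {
            "messages": [{"role": "user", "content": prompt}],
            "response_format": response_format,
        }
    # No structured output: the plain completions endpoint expects "prompt".
    return {"prompt": prompt}


def run_completion(client: Any, params: Dict[str, Any]) -> str:
    """Dispatch to the endpoint implied by the params, then normalize content."""
    if "messages" in params:
        response = client.chat.completions.create(**params)
        content = response.choices[0].message.content
    else:
        response = client.completions.create(**params)
        content = response.choices[0].text
    # Some backends return tokenized content as a list; join it into a string.
    return "".join(content) if isinstance(content, list) else content
```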