add comments

Honglin Cao 2025-03-12 22:06:08 -04:00
parent 3ab672dcda
commit 0cef9adda5


@@ -1,4 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -55,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CentMLImplConfig
# Example model aliases that map from CentML's published model identifiers.
# Update this if the list of models changes.
MODEL_ALIASES = [
build_model_entry(
"meta-llama/Llama-3.2-3B-Instruct",
@@ -151,26 +150,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
"""
params = await self._get_params(request)
if request.response_format is not None:
# ***** HACK: Use the chat completions endpoint even for non-chat completions
# This is necessary because CentML's structured output (JSON schema) support
# is only available via the chat API. However, our API expects a CompletionResponse.
response = self._get_client().chat.completions.create(**params)
choice = response.choices[0]
message = choice.message
# If message.content is returned as a list of tokens, join them into a string.
content = "".join(message.content) if isinstance(message.content, list) else message.content
return CompletionResponse(
content=content,
stop_reason="end_of_message",  # ***** HACK: Hard-coded stop_reason because the chat API doesn't return one.
logprobs=None,
)
else:
# ***** HACK: For non-structured outputs, ensure we use the completions endpoint.
# _get_params may include a "messages" key due to our unified parameter builder.
# We remove "messages" and instead set a "prompt" since the completions endpoint expects it.
prompt_str = await completion_request_to_prompt(request)
params.pop("messages", None)
params["prompt"] = prompt_str
response = self._get_client().completions.create(**params)
result = process_completion_response(response)
# Join tokenized content if needed.
if isinstance(result.content, list):
result.content = "".join(result.content)
return result
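Taken together, the two branches route structured requests to the chat API and everything else to the plain completions API. A minimal standalone sketch of the same routing, assuming an OpenAI-compatible client for CentML (the base_url, model, and helper names are illustrative, not taken from this repo):

from openai import OpenAI

client = OpenAI(base_url="https://api.centml.com/openai/v1",  # illustrative URL
                api_key="...")

def complete(params: dict, prompt: str, structured: bool) -> str:
    if structured:
        # Structured output only exists on the chat API, so wrap the
        # prompt as a single user message and call chat.completions.
        chat_params = {**params, "messages": [{"role": "user", "content": prompt}]}
        resp = client.chat.completions.create(**chat_params)
        content = resp.choices[0].message.content
        return "".join(content) if isinstance(content, list) else content
    # Plain completions: the endpoint expects "prompt", not "messages".
    plain_params = {k: v for k, v in params.items() if k != "messages"}
    resp = client.completions.create(prompt=prompt, **plain_params)
    return resp.choices[0].text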
@@ -180,6 +185,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
params = await self._get_params(request)
async def _to_async_generator():
# ***** HACK: For streaming structured outputs, use the chat completions endpoint.
# Otherwise, use the regular completions endpoint.
if request.response_format is not None:
stream = self._get_client().chat.completions.create(**params)
else:
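The _to_async_generator helper follows a common pattern for adapting a blocking SDK stream to `async for`. A generic sketch (names are illustrative; the sentinel avoids leaking StopIteration out of the coroutine):

import asyncio
from typing import AsyncIterator, Iterable, TypeVar

T = TypeVar("T")
_SENTINEL = object()

async def to_async_iter(sync_iter: Iterable[T]) -> AsyncIterator[T]:
    # Pull each chunk from the blocking iterator in a worker thread so
    # the event loop stays responsive between chunks.
    it = iter(sync_iter)
    while True:
        chunk = await asyncio.to_thread(lambda: next(it, _SENTINEL))
        if chunk is _SENTINEL:
            return
        yield chunk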
@@ -236,11 +243,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
# Use the chat completions endpoint if "messages" key is present.
if "messages" in params:
response = self._get_client().chat.completions.create(**params)
else:
response = self._get_client().completions.create(**params)
result = process_chat_completion_response(response, request)
# ***** HACK: Sometimes the returned content is tokenized as a list.
# We join the tokens into a single string to produce a unified output.
if request.response_format is not None:
if isinstance(result.completion_message, dict):
content = result.completion_message.get("content")
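This join-the-tokens normalization recurs in several paths and could be factored into a single helper, sketched here with an illustrative name:

def _joined(content):
    # Some CentML responses return content as a list of string tokens;
    # collapse it so downstream consumers always see a single str.
    return "".join(content) if isinstance(content, list) else content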
@@ -261,6 +271,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
params = await self._get_params(request)
async def _to_async_generator():
# ***** HACK: Use the chat completions endpoint if "messages" key is present.
if "messages" in params:
stream = self._get_client().chat.completions.create(**params)
else:
@@ -280,6 +291,11 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
async def _get_params(
self, request: Union[ChatCompletionRequest,
CompletionRequest]) -> dict:
"""
Build a unified set of parameters for both chat and non-chat requests.
When a structured output is specified (response_format is not None), we force
the use of a "messages" array even for CompletionRequests.
"""
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
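For reference, the two parameter shapes this builder produces, which the callers above route on via the "messages" key (values are illustrative, not taken from a real request):

chat_style = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",
    "messages": [{"role": "user", "content": "Hello"}],  # chat, or structured completion
    "stream": False,
}
prompt_style = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",
    "prompt": "Hello",  # plain completion
    "stream": False,
}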
@@ -290,6 +306,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
for m in request.messages
]
else:
# ***** HACK: For CompletionRequests with structured output,
# we simulate a chat conversation by wrapping the prompt as a single user message.
prompt_str = await completion_request_to_prompt(request)
input_dict["messages"] = [{
"role": "user",
@@ -325,6 +343,10 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
logprobs: Optional[LogProbConfig],
fmt: Optional[ResponseFormat],
) -> dict:
"""
Build additional options such as sampling parameters and logprobs.
Also translates our response_format into the format expected by CentML's API.
"""
options = get_sampling_options(sampling_params)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
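A hedged sketch of where the json_schema branch presumably lands; the exact nested payload CentML expects is an assumption modeled on the OpenAI-style response_format parameter, so verify it against CentML's API before relying on it:

def translate_format(fmt_type: str, schema: dict) -> dict:
    # Assumed mapping onto an OpenAI-compatible "response_format" field.
    if fmt_type == "json_schema":
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": {"name": "response", "schema": schema},
            }
        }
    if fmt_type == "grammar":
        raise NotImplementedError("Grammar-constrained output is not supported")
    raise ValueError(f"Unknown response format type: {fmt_type}")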
@@ -356,8 +378,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
output_dimension: Optional[int],
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
# CentML does not support media for embeddings.
assert all(not content_has_media(c) for c in contents), (
"CentML does not support media for embeddings")
resp = self._get_client().embeddings.create(