mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
add comments
This commit is contained in:
parent
3ab672dcda
commit
0cef9adda5
1 changed files with 28 additions and 6 deletions
|
@ -1,4 +1,3 @@
|
|||
# centml.py (updated)
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
@ -55,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import CentMLImplConfig
|
||||
|
||||
# Example model aliases that map from CentML’s published model identifiers
|
||||
# Update this if list of model changes.
|
||||
MODEL_ALIASES = [
|
||||
build_model_entry(
|
||||
"meta-llama/Llama-3.2-3B-Instruct",
|
||||
|
@ -151,26 +150,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
"""
|
||||
params = await self._get_params(request)
|
||||
if request.response_format is not None:
|
||||
# Use the chat completions endpoint for structured output.
|
||||
# ***** HACK: Use the chat completions endpoint even for non-chat completions
|
||||
# This is necessary because CentML's structured output (JSON schema) support
|
||||
# is only available via the chat API. However, our API expects a CompletionResponse.
|
||||
response = self._get_client().chat.completions.create(**params)
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
# If message.content is returned as a list of tokens, join them into a string.
|
||||
content = message.content if not isinstance(
|
||||
message.content, list) else "".join(message.content)
|
||||
return CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=
|
||||
"end_of_message", # hard code for now. need to fix later.
|
||||
"end_of_message", # ***** HACK: Hard-coded stop_reason because the chat API doesn't return one.
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
# Use the completions endpoint with a prompt.
|
||||
# ***** HACK: For non-structured outputs, ensure we use the completions endpoint.
|
||||
# _get_params may include a "messages" key due to our unified parameter builder.
|
||||
# We remove "messages" and instead set a "prompt" since the completions endpoint expects it.
|
||||
prompt_str = await completion_request_to_prompt(request)
|
||||
if "messages" in params:
|
||||
del params["messages"]
|
||||
params["prompt"] = prompt_str
|
||||
response = self._get_client().completions.create(**params)
|
||||
result = process_completion_response(response)
|
||||
# Join tokenized content if needed.
|
||||
if isinstance(result.content, list):
|
||||
result.content = "".join(result.content)
|
||||
return result
|
||||
|
@ -180,6 +185,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
# ***** HACK: For streaming structured outputs, use the chat completions endpoint.
|
||||
# Otherwise, use the regular completions endpoint.
|
||||
if request.response_format is not None:
|
||||
stream = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
|
@ -236,11 +243,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
# Use the chat completions endpoint if "messages" key is present.
|
||||
if "messages" in params:
|
||||
response = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
response = self._get_client().completions.create(**params)
|
||||
result = process_chat_completion_response(response, request)
|
||||
# ***** HACK: Sometimes the returned content is tokenized as a list.
|
||||
# We join the tokens into a single string to produce a unified output.
|
||||
if request.response_format is not None:
|
||||
if isinstance(result.completion_message, dict):
|
||||
content = result.completion_message.get("content")
|
||||
|
@ -261,6 +271,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
# ***** HACK: Use the chat completions endpoint if "messages" key is present.
|
||||
if "messages" in params:
|
||||
stream = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
|
@ -280,6 +291,11 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
async def _get_params(
|
||||
self, request: Union[ChatCompletionRequest,
|
||||
CompletionRequest]) -> dict:
|
||||
"""
|
||||
Build a unified set of parameters for both chat and non-chat requests.
|
||||
When a structured output is specified (response_format is not None), we force
|
||||
the use of a "messages" array even for CompletionRequests.
|
||||
"""
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
|
@ -290,6 +306,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
for m in request.messages
|
||||
]
|
||||
else:
|
||||
# ***** HACK: For CompletionRequests with structured output,
|
||||
# we simulate a chat conversation by wrapping the prompt as a single user message.
|
||||
prompt_str = await completion_request_to_prompt(request)
|
||||
input_dict["messages"] = [{
|
||||
"role": "user",
|
||||
|
@ -325,6 +343,10 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
logprobs: Optional[LogProbConfig],
|
||||
fmt: Optional[ResponseFormat],
|
||||
) -> dict:
|
||||
"""
|
||||
Build additional options such as sampling parameters and logprobs.
|
||||
Also translates our response_format into the format expected by CentML's API.
|
||||
"""
|
||||
options = get_sampling_options(sampling_params)
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
|
@ -356,8 +378,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
|||
output_dimension: Optional[int],
|
||||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
# this will come in future updates
|
||||
model = await self.model_store.get_model(model_id)
|
||||
# CentML does not support media for embeddings.
|
||||
assert all(not content_has_media(c) for c in contents), (
|
||||
"CentML does not support media for embeddings")
|
||||
resp = self._get_client().embeddings.create(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue