mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-11 20:40:40 +00:00

add comments

This commit is contained in:
parent 3ab672dcda
commit 0cef9adda5

1 changed file with 28 additions and 6 deletions
centml.py

@@ -1,4 +1,3 @@
-# centml.py (updated)
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -55,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import CentMLImplConfig

-# Example model aliases that map from CentML's published model identifiers
+# Update this if the list of models changes.
 MODEL_ALIASES = [
     build_model_entry(
         "meta-llama/Llama-3.2-3B-Instruct",
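
Note: the alias list maps CentML's published model identifiers to llama-stack's model descriptors so either name resolves to the same provider model. A self-contained sketch of that idea (the table, descriptor string, and resolve helper are illustrative stand-ins, not the adapter's actual registry):

    # Hypothetical stand-in for the alias table: CentML's published identifier on
    # the left, the llama-stack model descriptor it should resolve to on the right.
    CENTML_TO_LLAMA_STACK = {
        "meta-llama/Llama-3.2-3B-Instruct": "Llama3.2-3B-Instruct",
    }

    def resolve(model_id: str) -> str:
        # Accept either name so callers can use whichever identifier they know.
        for provider_id, descriptor in CENTML_TO_LLAMA_STACK.items():
            if model_id in (provider_id, descriptor):
                return provider_id
        raise KeyError(f"unknown model: {model_id}")

    print(resolve("Llama3.2-3B-Instruct"))  # -> meta-llama/Llama-3.2-3B-Instruct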
@@ -151,26 +150,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         """
         params = await self._get_params(request)
         if request.response_format is not None:
-            # Use the chat completions endpoint for structured output.
+            # ***** HACK: Use the chat completions endpoint even for non-chat completions
+            # This is necessary because CentML's structured output (JSON schema) support
+            # is only available via the chat API. However, our API expects a CompletionResponse.
             response = self._get_client().chat.completions.create(**params)
             choice = response.choices[0]
             message = choice.message
+            # If message.content is returned as a list of tokens, join them into a string.
             content = message.content if not isinstance(
                 message.content, list) else "".join(message.content)
             return CompletionResponse(
                 content=content,
                 stop_reason=
-                "end_of_message",  # hard code for now. need to fix later.
+                "end_of_message",  # ***** HACK: Hard-coded stop_reason because the chat API doesn't return one.
                 logprobs=None,
             )
         else:
-            # Use the completions endpoint with a prompt.
+            # ***** HACK: For non-structured outputs, ensure we use the completions endpoint.
+            # _get_params may include a "messages" key due to our unified parameter builder.
+            # We remove "messages" and instead set a "prompt" since the completions endpoint expects it.
             prompt_str = await completion_request_to_prompt(request)
             if "messages" in params:
                 del params["messages"]
             params["prompt"] = prompt_str
             response = self._get_client().completions.create(**params)
             result = process_completion_response(response)
+            # Join tokenized content if needed.
             if isinstance(result.content, list):
                 result.content = "".join(result.content)
             return result
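
Note: the hunk above documents two workarounds: structured output is only reachable through the chat API, and the plain completions endpoint wants a "prompt" rather than "messages". A standalone sketch of that reshaping and of the token-list normalization (illustrative names, not the adapter's actual helpers):

    from typing import Any, Dict, List, Optional, Union

    def normalize_content(content: Union[str, List[str], None]) -> Optional[str]:
        # Some backends return message content as a list of token strings;
        # join them so callers always see a single string.
        if isinstance(content, list):
            return "".join(content)
        return content

    def to_completions_params(params: Dict[str, Any], prompt: str) -> Dict[str, Any]:
        # The completions endpoint expects "prompt"; drop any "messages" key
        # that a unified parameter builder may have produced.
        params = dict(params)  # avoid mutating the caller's dict
        params.pop("messages", None)
        params["prompt"] = prompt
        return params

    # Example:
    print(normalize_content(["Hel", "lo"]))  # -> "Hello"
    print(to_completions_params({"model": "m", "messages": []}, "Hello"))
    # -> {'model': 'm', 'prompt': 'Hello'}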
@@ -180,6 +185,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         params = await self._get_params(request)

         async def _to_async_generator():
+            # ***** HACK: For streaming structured outputs, use the chat completions endpoint.
+            # Otherwise, use the regular completions endpoint.
             if request.response_format is not None:
                 stream = self._get_client().chat.completions.create(**params)
             else:
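
Note: for the streaming path, the inner _to_async_generator wraps the client's synchronous stream so it can be consumed with async for. A minimal sketch of that pattern (a plain iterable stands in for the CentML client stream):

    import asyncio
    from typing import AsyncIterator, Iterable, TypeVar

    T = TypeVar("T")

    async def to_async_generator(chunks: Iterable[T]) -> AsyncIterator[T]:
        # Yield each chunk from a synchronous stream, giving the event loop a
        # chance to run other tasks between chunks.
        for chunk in chunks:
            yield chunk
            await asyncio.sleep(0)

    async def main() -> None:
        async for chunk in to_async_generator(["partial ", "tokens ", "here"]):
            print(chunk, end="")

    asyncio.run(main())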
@@ -236,11 +243,14 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _nonstream_chat_completion(
             self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
+        # Use the chat completions endpoint if "messages" key is present.
         if "messages" in params:
             response = self._get_client().chat.completions.create(**params)
         else:
             response = self._get_client().completions.create(**params)
         result = process_chat_completion_response(response, request)
+        # ***** HACK: Sometimes the returned content is tokenized as a list.
+        # We join the tokens into a single string to produce a unified output.
         if request.response_format is not None:
             if isinstance(result.completion_message, dict):
                 content = result.completion_message.get("content")
@@ -261,6 +271,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         params = await self._get_params(request)

         async def _to_async_generator():
+            # ***** HACK: Use the chat completions endpoint if "messages" key is present.
             if "messages" in params:
                 stream = self._get_client().chat.completions.create(**params)
             else:
@@ -280,6 +291,11 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _get_params(
             self, request: Union[ChatCompletionRequest,
                                  CompletionRequest]) -> dict:
+        """
+        Build a unified set of parameters for both chat and non-chat requests.
+        When a structured output is specified (response_format is not None), we force
+        the use of a "messages" array even for CompletionRequests.
+        """
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
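
Note: the new docstring describes the unified parameter builder: chat requests always carry a "messages" array, and completion requests that ask for structured output get their prompt wrapped as a single user message (see the next hunk), while plain completions keep a "prompt". A simplified sketch of that decision, assuming the prompt and messages have already been rendered to plain strings and dicts:

    from typing import Any, Dict, List, Optional

    def build_input_dict(
        messages: Optional[List[Dict[str, str]]],
        prompt: Optional[str],
        has_response_format: bool,
    ) -> Dict[str, Any]:
        # Chat requests always carry messages. For completion requests that ask
        # for structured output, wrap the prompt as a single user message so the
        # chat endpoint (the only one with JSON-schema support) can be used.
        if messages is not None:
            return {"messages": messages}
        if has_response_format:
            return {"messages": [{"role": "user", "content": prompt}]}
        return {"prompt": prompt}

    print(build_input_dict(None, "Summarize this.", has_response_format=True))
    # -> {'messages': [{'role': 'user', 'content': 'Summarize this.'}]}
    print(build_input_dict(None, "Summarize this.", has_response_format=False))
    # -> {'prompt': 'Summarize this.'}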
|
@ -290,6 +306,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
|
||||||
for m in request.messages
|
for m in request.messages
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
# ***** HACK: For CompletionRequests with structured output,
|
||||||
|
# we simulate a chat conversation by wrapping the prompt as a single user message.
|
||||||
prompt_str = await completion_request_to_prompt(request)
|
prompt_str = await completion_request_to_prompt(request)
|
||||||
input_dict["messages"] = [{
|
input_dict["messages"] = [{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@@ -325,6 +343,10 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             logprobs: Optional[LogProbConfig],
             fmt: Optional[ResponseFormat],
     ) -> dict:
+        """
+        Build additional options such as sampling parameters and logprobs.
+        Also translates our response_format into the format expected by CentML's API.
+        """
         options = get_sampling_options(sampling_params)
         if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
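
Note: the new docstring mentions translating the stack's response_format into what CentML's API expects. Assuming an OpenAI-compatible endpoint, the json_schema case plausibly becomes a response_format payload along these lines; the exact wire format is an assumption, not confirmed by this diff:

    from typing import Any, Dict

    def translate_response_format(fmt_type: str, schema: Dict[str, Any]) -> Dict[str, Any]:
        # Assumed mapping to an OpenAI-style structured-output request body.
        if fmt_type == "json_schema":
            return {
                "response_format": {
                    "type": "json_schema",
                    "json_schema": {"name": "response", "schema": schema},
                }
            }
        raise ValueError(f"unsupported response format: {fmt_type}")

    print(translate_response_format(
        "json_schema",
        {"type": "object", "properties": {"answer": {"type": "string"}}},
    ))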
@@ -356,8 +378,8 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             output_dimension: Optional[int],
             contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
+        # this will come in future updates
         model = await self.model_store.get_model(model_id)
-        # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
             "CentML does not support media for embeddings")
         resp = self._get_client().embeddings.create(
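
Note: the embeddings path asserts that no media is present and then calls the client's embeddings endpoint. With an OpenAI-compatible client, extracting the vectors typically looks like the sketch below; the client construction, base URL, and model id are assumptions, not taken from the adapter:

    from openai import OpenAI  # assumed OpenAI-compatible client

    # Hypothetical endpoint and credentials; substitute real values.
    client = OpenAI(base_url="https://api.centml.example/v1", api_key="...")

    resp = client.embeddings.create(
        model="meta-llama/Llama-3.2-3B-Instruct",  # illustrative model id
        input=["first text", "second text"],       # plain strings only; no media
    )
    # Each item in resp.data carries one embedding vector, in input order.
    embeddings = [item.embedding for item in resp.data]
    print(len(embeddings), len(embeddings[0]))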