# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, List, Optional, Union

from llama_models.datatypes import CoreModelId
from openai import OpenAI

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
    content_has_media,
    interleaved_content_as_str,
    request_has_media,
)

from .config import CentMLImplConfig
# Update this if the list of models changes.
MODEL_ALIASES = [
    build_model_entry(
        "meta-llama/Llama-3.2-3B-Instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_entry(
        "meta-llama/Llama-3.3-70B-Instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
]
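
# Example (sketch): each entry maps CentML's model name to a core Llama model id,
# so adding support for another CentML-served model is a single new entry, e.g.
# (hypothetical model, assuming CentML serves it):
#
#     build_model_entry(
#         "meta-llama/Llama-3.1-8B-Instruct",
#         CoreModelId.llama3_1_8b_instruct.value,
#     ),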


class CentMLInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    """
    Adapter to use CentML's serverless inference endpoints, which adhere to
    the OpenAI chat/completions API spec, inside llama-stack.
    """

    def __init__(self, config: CentMLImplConfig) -> None:
        super().__init__(MODEL_ALIASES)
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    def _get_api_key(self) -> str:
        """
        Obtain the CentML API key, either from the adapter config or from the
        dynamic provider data in the request headers.
        """
        if self.config.api_key is not None:
            return self.config.api_key.get_secret_value()

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.centml_api_key:
            raise ValueError(
                'Pass CentML API Key in the header X-LlamaStack-ProviderData as { "centml_api_key": "<your-api-key>" }'
            )
        return provider_data.centml_api_key
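
    # Example (sketch): instead of configuring a static key, callers may supply it
    # per request via the X-LlamaStack-ProviderData header. With the Python client
    # this looks roughly like (assuming llama_stack_client's provider_data kwarg):
    #
    #     client = LlamaStackClient(
    #         base_url="http://localhost:5001",
    #         provider_data={"centml_api_key": "<your-api-key>"},
    #     )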

    def _get_client(self) -> OpenAI:
        """
        Creates an OpenAI-compatible client pointing to CentML's base URL,
        using the user's CentML API key.
        """
        api_key = self._get_api_key()
        return OpenAI(api_key=api_key, base_url=self.config.url)

    #
    # COMPLETION (non-chat)
    #

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,  # Avoid function call in default argument.
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """
        For "completion" style requests (non-chat).
        """
        # Instantiate sampling_params if not provided.
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
        """
        Process non-streaming completion requests.

        If a structured output is specified (e.g. a JSON schema), the adapter
        calls the chat completions endpoint and then converts the chat response
        into a plain CompletionResponse. Otherwise, it uses the regular
        completions endpoint.
        """
        params = await self._get_params(request)
        if request.response_format is not None:
            # ***** HACK: Use the chat completions endpoint even for non-chat completions.
            # This is necessary because CentML's structured output (JSON schema) support
            # is only available via the chat API. However, our API expects a CompletionResponse.
            response = self._get_client().chat.completions.create(**params)
            choice = response.choices[0]
            message = choice.message
            # If message.content is returned as a list of tokens, join them into a string.
            content = message.content if not isinstance(message.content, list) else "".join(message.content)
            return CompletionResponse(
                content=content,
                stop_reason="end_of_message",  # ***** HACK: Hard-coded stop_reason because the chat API doesn't return one.
                logprobs=None,
            )
        else:
            # ***** HACK: For non-structured outputs, ensure we use the completions endpoint.
            # _get_params may include a "messages" key due to our unified parameter builder.
            # We remove "messages" and instead set a "prompt", since the completions endpoint expects it.
            prompt_str = await completion_request_to_prompt(request)
            if "messages" in params:
                del params["messages"]
            params["prompt"] = prompt_str
            response = self._get_client().completions.create(**params)
            result = process_completion_response(response)
            # Join tokenized content if needed.
            if isinstance(result.content, list):
                result.content = "".join(result.content)
            return result
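
    # Example (sketch): a structured-output completion that exercises the chat
    # fallback above, assuming JsonSchemaResponseFormat from llama_stack.apis.inference:
    #
    #     response = await adapter.completion(
    #         model_id="meta-llama/Llama-3.2-3B-Instruct",
    #         content="Describe a person as JSON.",
    #         response_format=JsonSchemaResponseFormat(
    #             json_schema={"type": "object", "properties": {"name": {"type": "string"}}}
    #         ),
    #     )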

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        async def _to_async_generator():
            # ***** HACK: For streaming structured outputs, use the chat completions endpoint.
            # Otherwise, use the regular completions endpoint.
            if request.response_format is not None:
                stream = self._get_client().chat.completions.create(**params)
            else:
                stream = self._get_client().completions.create(**params)
            for chunk in stream:
                yield chunk

        stream = _to_async_generator()
        if request.response_format is not None:
            async for chunk in process_chat_completion_stream_response(stream, request):
                yield chunk
        else:
            async for chunk in process_completion_stream_response(stream):
                yield chunk

    #
    # CHAT COMPLETION
    #

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,  # Avoid function call in default argument.
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        """
        For "chat completion" style requests.
        """
        # Instantiate sampling_params if not provided.
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)
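
    # Example (sketch): a streaming chat call, assuming UserMessage from
    # llama_stack.apis.inference; awaiting the coroutine yields the generator:
    #
    #     stream = await adapter.chat_completion(
    #         model_id="meta-llama/Llama-3.2-3B-Instruct",
    #         messages=[UserMessage(content="Hello!")],
    #         stream=True,
    #     )
    #     async for chunk in stream:
    #         print(chunk.event.delta)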

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        # Use the chat completions endpoint if the "messages" key is present.
        if "messages" in params:
            response = self._get_client().chat.completions.create(**params)
        else:
            response = self._get_client().completions.create(**params)
        result = process_chat_completion_response(response, request)
        # ***** HACK: Sometimes the returned content is tokenized as a list.
        # We join the tokens into a single string to produce a unified output.
        if request.response_format is not None:
            if isinstance(result.completion_message, dict):
                content = result.completion_message.get("content")
                if isinstance(content, list):
                    result.completion_message["content"] = "".join(content)
            else:
                if isinstance(result.completion_message.content, list):
                    updated_msg = result.completion_message.copy(
                        update={"content": "".join(result.completion_message.content)}
                    )
                    result = result.copy(update={"completion_message": updated_msg})
        return result

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        async def _to_async_generator():
            # ***** HACK: Use the chat completions endpoint if the "messages" key is present.
            if "messages" in params:
                stream = self._get_client().chat.completions.create(**params)
            else:
                stream = self._get_client().completions.create(**params)
            for chunk in stream:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    #
    # HELPER METHODS
    #

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        """
        Build a unified set of parameters for both chat and non-chat requests.
        When a structured output is specified (response_format is not None), we
        force the use of a "messages" array even for CompletionRequests.
        """
        input_dict = {}
        media_present = request_has_media(request)
        llama_model = self.get_llama_model(request.model)
        if request.response_format is not None:
            if isinstance(request, ChatCompletionRequest):
                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
            else:
                # ***** HACK: For CompletionRequests with structured output,
                # we simulate a chat conversation by wrapping the prompt as a single user message.
                prompt_str = await completion_request_to_prompt(request)
                input_dict["messages"] = [{"role": "user", "content": prompt_str}]
        else:
            if isinstance(request, ChatCompletionRequest):
                if media_present or not llama_model:
                    input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
                else:
                    input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
            else:
                input_dict["prompt"] = await completion_request_to_prompt(request)

        params = {
            "model": request.model,
            **input_dict,
            "stream": request.stream,
            **self._build_options(request.sampling_params, request.logprobs, request.response_format),
        }
        return params
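
    # Example (sketch): for a plain (non-structured) CompletionRequest, the params
    # built above look roughly like:
    #
    #     {
    #         "model": "meta-llama/Llama-3.2-3B-Instruct",
    #         "prompt": "<rendered prompt>",
    #         "stream": False,
    #         # ...plus sampling options from _build_options, e.g. "temperature": 0.7
    #     }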

    def _build_options(
        self,
        sampling_params: Optional[SamplingParams],
        logprobs: Optional[LogProbConfig],
        fmt: Optional[ResponseFormat],
    ) -> dict:
        """
        Build additional options such as sampling parameters and logprobs.
        Also translates our response_format into the format expected by CentML's API.
        """
        options = get_sampling_options(sampling_params)
        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {"name": "schema", "schema": fmt.json_schema},
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                raise NotImplementedError("Grammar response format not supported yet")
            else:
                raise ValueError(f"Unknown response format {fmt.type}")
        if logprobs and logprobs.top_k:
            options["logprobs"] = logprobs.top_k
        return options
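
    # Example (sketch): a JSON-schema ResponseFormat such as
    # fmt.json_schema == {"type": "object"} is translated into the OpenAI-style payload:
    #
    #     {
    #         "response_format": {
    #             "type": "json_schema",
    #             "json_schema": {"name": "schema", "schema": {"type": "object"}},
    #         }
    #     }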

    #
    # EMBEDDINGS
    #

    async def embeddings(
        self,
        task_type: str,
        model_id: str,
        text_truncation: Optional[str],
        output_dimension: Optional[int],
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
        # ***** HACK/ASSERT: CentML does not support media for embeddings.
        # We assert here to catch any cases where media is inadvertently included.
        assert all(not content_has_media(c) for c in contents), "CentML does not support media for embeddings"
        resp = self._get_client().embeddings.create(
            model=model.provider_resource_id,
            input=[interleaved_content_as_str(c) for c in contents],
        )
        # NOTE: the lines below are a sketch reconstructed from the commented-out
        # call above; OpenAI-compatible embeddings responses carry the vectors
        # in resp.data[i].embedding.
        embeddings = [item.embedding for item in resp.data]
        return EmbeddingsResponse(embeddings=embeddings)