Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-11 20:40:40 +00:00

commit e31a52b26e
parent 98549b826d

    fix endpoint

1 changed file with 44 additions and 46 deletions
@@ -8,6 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from openai import OpenAI
 
+from llama_stack import logcat
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
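The only functional change in this hunk is the new logcat import. It backs the debug line that a later hunk adds at the end of _get_params, roughly:

    # Added further down in _get_params: log the outgoing request parameters
    # under the "inference" category before they are sent to CentML.
    logcat.debug("inference", f"params to centml: {params}")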
@@ -75,7 +76,6 @@ class CentMLInferenceAdapter(
     def __init__(self, config: CentMLImplConfig) -> None:
         super().__init__(MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         pass
@@ -127,7 +127,8 @@ class CentMLInferenceAdapter(
             model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
-            response_format=response_format,
+            # Completions.create() got an unexpected keyword argument 'response_format'
+            #response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
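The response_format keyword is commented out here because the OpenAI-style Completions.create() call rejects it. If the parameter ever needs to be stripped dynamically rather than commented out, a minimal sketch (mirroring the check the pre-commit _get_params performed; the helper name is hypothetical) could look like:

    def _drop_unsupported_response_format(params: dict) -> dict:
        # Hypothetical helper: the plain completions endpoint does not accept
        # 'response_format', so drop it whenever we are sending a raw "prompt".
        if "prompt" in params and "response_format" in params:
            params = dict(params)  # avoid mutating the caller's dict
            del params["response_format"]
        return params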
@@ -142,7 +143,7 @@ class CentMLInferenceAdapter(
         params = await self._get_params(request)
         # Using the older "completions" route for non-chat
         response = self._get_client().completions.create(**params)
-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)
 
     async def _stream_completion(
         self, request: CompletionRequest
@@ -156,7 +157,7 @@ class CentMLInferenceAdapter(
 
         stream = _to_async_generator()
         async for chunk in process_completion_stream_response(
-            stream, self.formatter
+            stream
         ):
             yield chunk
 
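For context, _to_async_generator follows the usual pattern in these adapters: wrap the blocking OpenAI stream in an async generator so the stream-processing helper can consume it with async for. Its definition is not part of this hunk; an assumed sketch, where client and params come from the enclosing method:

    async def _to_async_generator():
        # Assumed shape, not the file's exact code: with stream=True in params,
        # the OpenAI client returns a synchronous iterator of chunks.
        s = client.completions.create(**params)
        for chunk in s:
            yield chunk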
@@ -188,7 +189,8 @@ class CentMLInferenceAdapter(
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
-            response_format=response_format,
+            # Completions.create() got an unexpected keyword argument 'response_format'
+            #response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -209,7 +211,7 @@ class CentMLInferenceAdapter(
         # fallback if we ended up only with "prompt"
         response = self._get_client().completions.create(**params)
 
-        return process_chat_completion_response(response, self.formatter)
+        return process_chat_completion_response(response, request)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
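The "fallback" comment above refers to the dispatch in the surrounding, unchanged method: when _get_params built "messages", the chat endpoint is used, and only a "prompt"-shaped request falls back to the older completions route. An assumed sketch of that surrounding method (its name and the dispatch are not shown in this hunk):

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest):
        params = await self._get_params(request)
        if "messages" in params:
            response = self._get_client().chat.completions.create(**params)
        else:
            # fallback if we ended up only with "prompt"
            response = self._get_client().completions.create(**params)
        return process_chat_completion_response(response, request)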
@@ -226,62 +228,34 @@ class CentMLInferenceAdapter(
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-            stream, self.formatter
-        ):
+            stream, request):
             yield chunk
 
     #
     # HELPER METHODS
     #
 
-    async def _get_params(
-            self, request: Union[ChatCompletionRequest, CompletionRequest]
-    ) -> dict:
-        """
-        Build the 'params' dict that the OpenAI (CentML) client expects.
-        For chat requests, we always prefer "messages" so that it calls
-        the chat endpoint properly.
-        """
+    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            # For chat requests, always build "messages" from the user messages
-            input_dict["messages"] = [
-                await convert_message_to_openai_dict(m)
-                for m in request.messages
-            ]
+            if media_present or not llama_model:
+                input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
+            else:
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
-            # Non-chat (CompletionRequest)
-            assert not media_present, "CentML does not support media for completions"
-            input_dict["prompt"] = await completion_request_to_prompt(
-                request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
 
         params = {
-            "model":
-            request.model,
+            "model": request.model,
             **input_dict,
-            "stream":
-            request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
+            "stream": request.stream,
+            **self._build_options(request.sampling_params, request.logprobs, request.response_format),
         }
-        # For non-chat completions (i.e. when using a "prompt"), CentML's
-        # completions endpoint does not support the response_format parameter.
-        if "prompt" in params and "response_format" in params:
-            del params["response_format"]
-
-        # For chat completions with structured output, CentML requires
-        # guided decoding settings to use num_scheduler_steps=1 and spec_enabled=False.
-        # Override these if a response_format was requested.
-        if "messages" in params and request.response_format:
-            params["num_scheduler_steps"] = 1
-            params["spec_enabled"] = False
-
+        logcat.debug("inference", f"params to centml: {params}")
         return params
 
     def _build_options(
         self,
         sampling_params: Optional[SamplingParams],
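To make the new branching concrete, this is roughly what the rewritten _get_params returns for a text-only chat request when media is present or no Llama model mapping exists, so the "messages" branch is taken. All concrete values below are invented for illustration:

    params = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Say hello."}],
        "stream": False,
        # plus whatever _build_options(sampling_params, logprobs, response_format)
        # contributes, e.g. temperature/top_p and, for structured output,
        # a response_format entry.
    }

When a Llama model mapping does exist and no media is attached, the same chat request is instead flattened to a "prompt" via chat_completion_request_to_prompt, which routes it through the older completions endpoint.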
@@ -308,6 +282,30 @@ class CentMLInferenceAdapter(
 
         return options
 
+    def _build_options(
+        self,
+        sampling_params: Optional[SamplingParams],
+        logprobs: Optional[LogProbConfig],
+        fmt: ResponseFormat,
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.json_schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError(
+                    "Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        if logprobs and logprobs.top_k:
+            options["logprobs"] = 1
+
+        return options
+
     #
     # EMBEDDINGS
     #
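As a worked example of the new _build_options (all concrete values invented): for a json_schema response format plus a logprobs config with top_k set, the returned dict would look roughly like this, merged on top of whatever get_sampling_options produced:

    options = {
        "temperature": 0.7,  # example sampling option from get_sampling_options
        "response_format": {
            "type": "json_object",
            "schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
        },
        "logprobs": 1,
    }

These keys are then spread into the request params by _get_params via **self._build_options(...), as shown in the previous hunk.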