mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-11 20:40:40 +00:00)
change to adapter
This commit is contained in: parent a454b53bda, commit e2290a0096
1 changed file with 107 additions and 31 deletions
@@ -1,3 +1,4 @@
+# centml.py (updated)
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -7,6 +8,7 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from openai import OpenAI
+from pydantic import parse_obj_as
 
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
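
Note on the new import: parse_obj_as is the pydantic v1 helper used further down to build a Message from a plain dict. A minimal, self-contained sketch of the same pattern, using a hypothetical ExampleMessage model rather than llama-stack's Message type:

    from pydantic import BaseModel, parse_obj_as

    class ExampleMessage(BaseModel):
        role: str
        content: str
        stop_reason: str

    # Validate a raw dict into a typed model, as done in _nonstream_completion below.
    msg = parse_obj_as(ExampleMessage, {
        "role": "assistant",
        "content": '{"answer": 42}',
        "stop_reason": "end_of_message",
    })
    assert msg.role == "assistant"
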
@@ -53,8 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import CentMLImplConfig
 
-# Example model aliases that map from CentML's
-# published model identifiers to llama-stack's `CoreModelId`.
+# Example model aliases that map from CentML's published model identifiers
 MODEL_ALIASES = [
     build_model_entry(
         "meta-llama/Llama-3.2-3B-Instruct",
@@ -129,8 +130,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
-            # Completions.create() got an unexpected keyword argument 'response_format'
-            #response_format=response_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -142,22 +142,77 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _nonstream_completion(
             self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        # Using the older "completions" route for non-chat
-        response = self._get_client().completions.create(**params)
-        return process_completion_response(response)
+        if request.response_format is not None:
+            # For structured output, use the chat completions endpoint.
+            response = self._get_client().chat.completions.create(**params)
+            try:
+                result = process_chat_completion_response(response, request)
+            except KeyError as e:
+                if str(e) == "'parameters'":
+                    # CentML's structured output may not include a tool call.
+                    # Use the raw message content as the structured JSON.
+                    raw_content = response.choices[0].message.content
+                    message_obj = parse_obj_as(
+                        Message, {
+                            "role": "assistant",
+                            "content": raw_content,
+                            "stop_reason": "end_of_message"
+                        })
+                    result = ChatCompletionResponse(
+                        completion_message=message_obj,
+                        logprobs=None,
+                    )
+                else:
+                    raise
+            # If the processed content is still None, use the raw API content.
+            if result.completion_message.content is None:
+                raw_content = response.choices[0].message.content
+                if isinstance(result.completion_message, dict):
+                    result.completion_message["content"] = raw_content
+                else:
+                    updated_msg = result.completion_message.copy(
+                        update={"content": raw_content})
+                    result = result.copy(
+                        update={"completion_message": updated_msg})
+        else:
+            response = self._get_client().completions.create(**params)
+            result = process_completion_response(response)
+        # If structured output returns token lists, join them.
+        if request.response_format is not None:
+            if isinstance(result.completion_message, dict):
+                content = result.completion_message.get("content")
+                if isinstance(content, list):
+                    result.completion_message["content"] = "".join(content)
+            else:
+                if isinstance(result.completion_message.content, list):
+                    updated_msg = result.completion_message.copy(update={
+                        "content":
+                        "".join(result.completion_message.content)
+                    })
+                    result = result.copy(
+                        update={"completion_message": updated_msg})
+        return result
 
     async def _stream_completion(self,
                                  request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
-            stream = self._get_client().completions.create(**params)
+            if request.response_format is not None:
+                stream = self._get_client().chat.completions.create(**params)
+            else:
+                stream = self._get_client().completions.create(**params)
             for chunk in stream:
                 yield chunk
 
         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream):
-            yield chunk
+        if request.response_format is not None:
+            async for chunk in process_chat_completion_stream_response(
+                    stream, request):
+                yield chunk
+        else:
+            async for chunk in process_completion_stream_response(stream):
+                yield chunk
 
     #
     # CHAT COMPLETION
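
The "join token lists" normalization above appears twice in this diff (here and again in _nonstream_chat_completion). A minimal sketch of that logic as a standalone helper; the name _joined_content is hypothetical, and the adapter inlines the logic rather than calling a helper:

    from typing import List, Optional, Union

    def _joined_content(content: Union[str, List[str], None]) -> Optional[str]:
        # Structured output can come back as a list of token strings;
        # collapse it into one string, otherwise pass it through unchanged.
        if isinstance(content, list):
            return "".join(content)
        return content

    assert _joined_content(["{", '"a"', ": 1", "}"]) == '{"a": 1}'
    assert _joined_content('{"a": 1}') == '{"a": 1}'
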
@@ -187,8 +242,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
-            # Completions.create() got an unexpected keyword argument 'response_format'
-            #response_format=response_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -200,15 +254,25 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _nonstream_chat_completion(
             self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
-        # For chat requests, if "messages" is in params -> .chat.completions
         if "messages" in params:
             response = self._get_client().chat.completions.create(**params)
         else:
-            # fallback if we ended up only with "prompt"
             response = self._get_client().completions.create(**params)
-        return process_chat_completion_response(response, request)
+        result = process_chat_completion_response(response, request)
+        if request.response_format is not None:
+            if isinstance(result.completion_message, dict):
+                content = result.completion_message.get("content")
+                if isinstance(content, list):
+                    result.completion_message["content"] = "".join(content)
+            else:
+                if isinstance(result.completion_message.content, list):
+                    updated_msg = result.completion_message.copy(update={
+                        "content":
+                        "".join(result.completion_message.content)
+                    })
+                    result = result.copy(
+                        update={"completion_message": updated_msg})
+        return result
 
     async def _stream_chat_completion(
             self, request: ChatCompletionRequest) -> AsyncGenerator:
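
The result.completion_message.copy(update={...}) and result.copy(update={...}) calls above are the pydantic v1 idiom for replacing a field on a model without mutating it in place. A small self-contained sketch with stand-in models (ExampleMsg and ExampleResponse are hypothetical, not llama-stack types):

    from typing import Optional
    from pydantic import BaseModel

    class ExampleMsg(BaseModel):
        content: Optional[str] = None

    class ExampleResponse(BaseModel):
        completion_message: ExampleMsg

    resp = ExampleResponse(completion_message=ExampleMsg(content=None))
    updated_msg = resp.completion_message.copy(update={"content": '{"ok": true}'})
    resp = resp.copy(update={"completion_message": updated_msg})
    assert resp.completion_message.content == '{"ok": true}'
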
@@ -237,18 +301,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            if media_present or not llama_model:
+        if request.response_format is not None:
+            if isinstance(request, ChatCompletionRequest):
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m)
                     for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, llama_model)
+                prompt_str = await completion_request_to_prompt(request)
+                input_dict["messages"] = [{
+                    "role": "user",
+                    "content": prompt_str
+                }]
         else:
-            input_dict["prompt"] = await completion_request_to_prompt(request)
+            if isinstance(request, ChatCompletionRequest):
+                if media_present or not llama_model:
+                    input_dict["messages"] = [
+                        await convert_message_to_openai_dict(m)
+                        for m in request.messages
+                    ]
+                else:
+                    input_dict[
+                        "prompt"] = await chat_completion_request_to_prompt(
+                            request, llama_model)
+            else:
+                input_dict["prompt"] = await completion_request_to_prompt(
+                    request)
         params = {
             "model":
             request.model,
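
For orientation, roughly what _get_params now produces for a plain CompletionRequest that asks for structured output: the rendered prompt is wrapped into a single user message so the request can be sent to chat.completions. The concrete values below are illustrative only:

    params = {
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "messages": [{
            "role": "user",
            "content": "List three fruits as a JSON array.",
        }],
        # sampling options, logprobs and response_format are merged in from
        # the options-building helper shown in the next hunk.
    }
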
@@ -263,26 +341,25 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         self,
         sampling_params: Optional[SamplingParams],
         logprobs: Optional[LogProbConfig],
-        fmt: ResponseFormat,
+        fmt: Optional[ResponseFormat],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["response_format"] = {
-                    "type": "json_object",
-                    # CentML currently does not support guided decoding,
-                    # the following setting is currently ignored by the server.
-                    "schema": fmt.json_schema,
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "schema",
+                        "schema": fmt.json_schema
+                    },
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
                 raise NotImplementedError(
                     "Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
 
         if logprobs and logprobs.top_k:
             options["logprobs"] = logprobs.top_k
 
         return options
 
     #
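
The options dict above now emits the OpenAI-style json_schema response format. A hedged, self-contained sketch of what the resulting request looks like when issued through the OpenAI client; the base_url, API key, and schema are placeholders, not CentML's real endpoint:

    from openai import OpenAI

    client = OpenAI(base_url="https://example.invalid/openai/v1", api_key="sk-...")
    resp = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Return a JSON object with a 'name' field."}],
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "schema",
                "schema": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}},
                    "required": ["name"],
                },
            },
        },
    )
    print(resp.choices[0].message.content)
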
@@ -301,7 +378,6 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
             "CentML does not support media for embeddings")
 
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(c) for c in contents],
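
The embeddings path (context above) flattens interleaved content to plain strings before calling the OpenAI-compatible endpoint. A minimal sketch of the equivalent direct call; endpoint, key, and model name are placeholders:

    from openai import OpenAI

    client = OpenAI(base_url="https://example.invalid/openai/v1", api_key="sk-...")
    texts = ["hello world", "structured output with CentML"]  # already plain strings
    resp = client.embeddings.create(model="example-embedding-model", input=texts)
    vectors = [item.embedding for item in resp.data]
    print(len(vectors), len(vectors[0]))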