change to adapter

Honglin Cao 2025-03-11 15:10:31 -04:00
parent a454b53bda
commit e2290a0096


@@ -1,3 +1,4 @@
+# centml.py (updated)
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
@@ -7,6 +8,7 @@
 from typing import AsyncGenerator, List, Optional, Union
 from openai import OpenAI
+from pydantic import parse_obj_as
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
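The new import brings in pydantic's `parse_obj_as`, which the structured-output fallback later in this diff uses to coerce a plain dict into a `Message`. A minimal sketch of that call, using a stand-in model rather than llama-stack's actual `Message` type:

```python
# Illustrative only: "AssistantMessage" is a stand-in, not llama-stack's Message.
from pydantic import BaseModel, parse_obj_as


class AssistantMessage(BaseModel):
    role: str
    content: str
    stop_reason: str


raw = {
    "role": "assistant",
    "content": '{"answer": 42}',
    "stop_reason": "end_of_message",
}
# parse_obj_as validates the dict against the model and returns an instance.
msg = parse_obj_as(AssistantMessage, raw)
print(msg.content)  # {"answer": 42}
```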
@@ -53,8 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import CentMLImplConfig

-# Example model aliases that map from CentML's
-# published model identifiers to llama-stack's `CoreModelId`.
+# Example model aliases that map from CentML's published model identifiers
 MODEL_ALIASES = [
     build_model_entry(
         "meta-llama/Llama-3.2-3B-Instruct",
@@ -129,8 +130,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
-            # Completions.create() got an unexpected keyword argument 'response_format'
-            #response_format=response_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -142,22 +142,77 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _nonstream_completion(
             self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        # Using the older "completions" route for non-chat
-        response = self._get_client().completions.create(**params)
-        return process_completion_response(response)
+        if request.response_format is not None:
+            # For structured output, use the chat completions endpoint.
+            response = self._get_client().chat.completions.create(**params)
+            try:
+                result = process_chat_completion_response(response, request)
+            except KeyError as e:
+                if str(e) == "'parameters'":
+                    # CentML's structured output may not include a tool call.
+                    # Use the raw message content as the structured JSON.
+                    raw_content = response.choices[0].message.content
+                    message_obj = parse_obj_as(
+                        Message, {
+                            "role": "assistant",
+                            "content": raw_content,
+                            "stop_reason": "end_of_message"
+                        })
+                    result = ChatCompletionResponse(
+                        completion_message=message_obj,
+                        logprobs=None,
+                    )
+                else:
+                    raise
+            # If the processed content is still None, use the raw API content.
+            if result.completion_message.content is None:
+                raw_content = response.choices[0].message.content
+                if isinstance(result.completion_message, dict):
+                    result.completion_message["content"] = raw_content
+                else:
+                    updated_msg = result.completion_message.copy(
+                        update={"content": raw_content})
+                    result = result.copy(
+                        update={"completion_message": updated_msg})
+        else:
+            response = self._get_client().completions.create(**params)
+            result = process_completion_response(response)
+        # If structured output returns token lists, join them.
+        if request.response_format is not None:
+            if isinstance(result.completion_message, dict):
+                content = result.completion_message.get("content")
+                if isinstance(content, list):
+                    result.completion_message["content"] = "".join(content)
+            else:
+                if isinstance(result.completion_message.content, list):
+                    updated_msg = result.completion_message.copy(update={
+                        "content":
+                        "".join(result.completion_message.content)
+                    })
+                    result = result.copy(
+                        update={"completion_message": updated_msg})
+        return result

     async def _stream_completion(self,
                                  request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         async def _to_async_generator():
-            stream = self._get_client().completions.create(**params)
+            if request.response_format is not None:
+                stream = self._get_client().chat.completions.create(**params)
+            else:
+                stream = self._get_client().completions.create(**params)
             for chunk in stream:
                 yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream):
-            yield chunk
+        if request.response_format is not None:
+            async for chunk in process_chat_completion_stream_response(
+                    stream, request):
+                yield chunk
+        else:
+            async for chunk in process_completion_stream_response(stream):
+                yield chunk
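The `_to_async_generator` helper above wraps the OpenAI client's synchronous stream so the adapter can `async for` over it. A standalone sketch of the same pattern, with a plain generator standing in for the blocking OpenAI stream (note the underlying iteration still blocks the event loop, which this adapter accepts):

```python
import asyncio
from typing import AsyncGenerator, Iterable


def fake_sync_stream() -> Iterable[str]:
    # Stand-in for the blocking stream returned by completions.create(stream=True).
    yield from ["Hel", "lo", "!"]


async def to_async_generator(chunks: Iterable[str]) -> AsyncGenerator[str, None]:
    # Same shape as _to_async_generator: iterate the sync stream and re-yield.
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator(fake_sync_stream()):
        print(chunk, end="")


asyncio.run(main())
```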
 #
 # CHAT COMPLETION
@@ -187,8 +242,7 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
-            # Completions.create() got an unexpected keyword argument 'response_format'
-            #response_format=response_format,
+            response_format=response_format,
             stream=stream,
             logprobs=logprobs,
         )
@@ -200,15 +254,25 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
     async def _nonstream_chat_completion(
             self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        # For chat requests, if "messages" is in params -> .chat.completions
         if "messages" in params:
             response = self._get_client().chat.completions.create(**params)
         else:
-            # fallback if we ended up only with "prompt"
             response = self._get_client().completions.create(**params)
-
-        return process_chat_completion_response(response, request)
+        result = process_chat_completion_response(response, request)
+        if request.response_format is not None:
+            if isinstance(result.completion_message, dict):
+                content = result.completion_message.get("content")
+                if isinstance(content, list):
+                    result.completion_message["content"] = "".join(content)
+            else:
+                if isinstance(result.completion_message.content, list):
+                    updated_msg = result.completion_message.copy(update={
+                        "content":
+                        "".join(result.completion_message.content)
+                    })
+                    result = result.copy(
+                        update={"completion_message": updated_msg})
+        return result

     async def _stream_chat_completion(
             self, request: ChatCompletionRequest) -> AsyncGenerator:
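Both non-streaming paths now normalize the case where the processed `completion_message.content` arrives as a list of string fragments rather than a single string. A rough standalone sketch of that normalization, using a generic pydantic model in place of the adapter's actual completion message type:

```python
from typing import List, Optional, Union
from pydantic import BaseModel


class FakeCompletionMessage(BaseModel):
    # Stand-in for the adapter's completion_message.
    content: Optional[Union[str, List[str]]] = None


def join_fragments(msg: FakeCompletionMessage) -> FakeCompletionMessage:
    # Mirrors the diff's logic: a list of fragments is joined into one string,
    # and pydantic models are updated immutably via .copy(update=...).
    if isinstance(msg.content, list):
        return msg.copy(update={"content": "".join(msg.content)})
    return msg


print(join_fragments(FakeCompletionMessage(content=['{"a": ', "1}"])).content)
# {"a": 1}
```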
@@ -237,18 +301,32 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         input_dict = {}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            if media_present or not llama_model:
+        if request.response_format is not None:
+            if isinstance(request, ChatCompletionRequest):
                 input_dict["messages"] = [
                     await convert_message_to_openai_dict(m)
                     for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, llama_model)
+                prompt_str = await completion_request_to_prompt(request)
+                input_dict["messages"] = [{
+                    "role": "user",
+                    "content": prompt_str
+                }]
         else:
-            input_dict["prompt"] = await completion_request_to_prompt(request)
+            if isinstance(request, ChatCompletionRequest):
+                if media_present or not llama_model:
+                    input_dict["messages"] = [
+                        await convert_message_to_openai_dict(m)
+                        for m in request.messages
+                    ]
+                else:
+                    input_dict[
+                        "prompt"] = await chat_completion_request_to_prompt(
+                            request, llama_model)
+            else:
+                input_dict["prompt"] = await completion_request_to_prompt(
+                    request)

         params = {
             "model":
                 request.model,
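`_get_params` now routes every structured-output request to the chat-style `messages` payload (wrapping a plain completion prompt in a single user message) and keeps the older `prompt` route only for unstructured requests. A simplified sketch of that routing decision, with plain flags standing in for the real request inspection (the media/llama-model check is collapsed into one boolean):

```python
from typing import Any, Dict, List


def build_payload_shape(is_chat: bool, has_response_format: bool,
                        use_openai_messages: bool,
                        messages: List[Dict[str, Any]],
                        prompt: str) -> Dict[str, Any]:
    # Structured output always goes through chat-style "messages".
    if has_response_format:
        if is_chat:
            return {"messages": messages}
        return {"messages": [{"role": "user", "content": prompt}]}
    # Otherwise: chat requests use "messages" when media is present (or no
    # known llama model), else a pre-rendered "prompt"; completions use "prompt".
    if is_chat and use_openai_messages:
        return {"messages": messages}
    return {"prompt": prompt}


print(build_payload_shape(False, True, False, [], "Return JSON for 6 * 7"))
# {'messages': [{'role': 'user', 'content': 'Return JSON for 6 * 7'}]}
```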
@@ -263,26 +341,25 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
             self,
             sampling_params: Optional[SamplingParams],
             logprobs: Optional[LogProbConfig],
-            fmt: ResponseFormat,
+            fmt: Optional[ResponseFormat],
     ) -> dict:
         options = get_sampling_options(sampling_params)
         if fmt:
             if fmt.type == ResponseFormatType.json_schema.value:
                 options["response_format"] = {
-                    "type": "json_object",
-                    # CentML currently does not support guided decoding,
-                    # the following setting is currently ignored by the server.
-                    "schema": fmt.json_schema,
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "schema",
+                        "schema": fmt.json_schema
+                    },
                 }
             elif fmt.type == ResponseFormatType.grammar.value:
                 raise NotImplementedError(
                     "Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
         if logprobs and logprobs.top_k:
             options["logprobs"] = logprobs.top_k
         return options
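The `_build_options` change moves from the older `json_object`-plus-`schema` shape to the OpenAI structured-output format, where the schema is nested under `json_schema` with a `name`. A hedged sketch of the resulting request against an OpenAI-compatible endpoint (base URL, API key, and model name are placeholders, and the server must actually support structured outputs):

```python
from openai import OpenAI

# Placeholder endpoint and key; a CentML deployment supplies its own values.
client = OpenAI(base_url="https://example.invalid/openai/v1", api_key="sk-placeholder")

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "schema",  # same fixed name _build_options now emits
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "integer"}},
            "required": ["answer"],
        },
    },
}

resp = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Reply in JSON: what is 6 * 7?"}],
    response_format=response_format,
)
print(resp.choices[0].message.content)  # e.g. {"answer": 42}
```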
 #
@@ -301,7 +378,6 @@ class CentMLInferenceAdapter(ModelRegistryHelper, Inference,
         # CentML does not support media for embeddings.
         assert all(not content_has_media(c) for c in contents), (
             "CentML does not support media for embeddings")
-
         resp = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(c) for c in contents],
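For context, the embeddings path is a thin wrapper over the OpenAI-compatible embeddings endpoint; a minimal usage sketch (base URL, key, and model name are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="https://example.invalid/openai/v1", api_key="sk-placeholder")

resp = client.embeddings.create(
    model="example-embedding-model",  # placeholder for the deployment's embedding model
    input=["CentML adapter embedding test"],
)
print(len(resp.data[0].embedding))  # vector dimensionality
```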