# What does this PR do?

Before this change, `distro_codegen.py` would only work if the user manually installed multiple provider-specific dependencies (see #1122). Now, users can run `distro_codegen.py` without any provider-specific dependencies because we avoid importing the entire provider implementations just to get the config needed to build the provider template.

Concretely, this mostly means moving the `MODEL_ALIASES` (and related variants) definitions to a new `models.py` module within the provider implementation for those providers that require additional dependencies. It also means moving a couple of imports from the module top level to inside `get_adapter_impl` for some providers, which follows the pattern already used by multiple existing providers.

To ensure we don't regress and accidentally add new imports that cause `distro_codegen.py` to fail, the stubbed-in pre-commit hook for `distro_codegen.py` was uncommented and slightly tweaked to run via `uv run python ...`, so that it runs with only the project's default dependencies, and to run automatically instead of manually.

Lastly, this updates `distro_codegen.py` itself to keep track of the paths it may have changed and to only `git diff` those specific paths when checking for changed files, instead of diffing the entire working tree. The latter was overly broad and would require a user to have no other unstaged changes in their working tree, even if those changes were unrelated to generated code. Now it only flags uncommitted changes for paths `distro_codegen.py` actually writes to.

Our generated code was also out of date, presumably because of these issues, so this commit also updates the generated code purely to bring it back in sync; the pre-commit hook now enforces that it stays updated.

(Closes #1122)

## Test Plan

I manually tested `distro_codegen.py` and the pre-commit hook to verify that they work as expected, flagging any uncommitted changes and catching any imports that attempt to pull in provider-specific dependencies. However, I do not have valid API keys for the impacted provider implementations, and am unable to easily run the inference tests against each changed provider.

There are no functional changes to the provider implementations here, but I'd appreciate a second set of eyes on the changed import statements and on the `MODEL_ALIASES`-style code moved to a separate `models.py`, to ensure I didn't make any obvious errors.

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
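For reviewers who want a concrete picture of the refactor, here is a minimal sketch of the two patterns described above, using the SambaNova provider as the example. The `build_model_alias` helper, the `CoreModelId` value, and the exact file layout are assumptions based on how other remote inference providers are structured, not a verbatim copy of this PR's diff.

```python
# models.py (sketch): provider-specific model aliases live in a small module
# with no heavy imports, so distro_codegen.py can load it cheaply.
# `build_model_alias` and `CoreModelId` are assumed helpers/enums from the
# existing model registry utilities; adjust to the actual signatures.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_model_alias

MODEL_ALIASES = [
    build_model_alias(
        "Meta-Llama-3.1-8B-Instruct",            # provider-side model id
        CoreModelId.llama3_1_8b_instruct.value,  # llama-stack model descriptor
    ),
]


# __init__.py (sketch): the adapter import is deferred into get_adapter_impl so
# that importing the provider package does not pull in provider-specific
# dependencies such as the `openai` client.
async def get_adapter_impl(config, _deps):
    from .sambanova import SambaNovaInferenceAdapter  # deferred heavy import

    impl = SambaNovaInferenceAdapter(config)
    await impl.initialize()
    return impl
```

The idea is that template generation only needs the lightweight `config.py`/`models.py` pieces, while the heavier imports in `sambanova.py` (shown below) are touched only when the adapter is actually instantiated.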
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    TextContentItem,
)
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.models.llama.datatypes import (
    GreedySamplingStrategy,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_content_to_url,
)

from .config import SambaNovaImplConfig
from .models import MODEL_ALIASES

class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: SambaNovaImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=MODEL_ALIASES,
        )

        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    def _get_client(self) -> OpenAI:
        return OpenAI(base_url=self.config.url, api_key=self.config.api_key)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        tool_config: Optional[ToolConfig] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)

        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        request_sambanova = await self.convert_chat_completion_request(request)

        if stream:
            return self._stream_chat_completion(request_sambanova)
        else:
            return await self._nonstream_chat_completion(request_sambanova)

    async def _nonstream_chat_completion(self, request: dict) -> ChatCompletionResponse:
        # `request` is the OpenAI-compatible dict produced by convert_chat_completion_request.
        response = self._get_client().chat.completions.create(**request)

        choice = response.choices[0]

        result = ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=choice.message.content or "",
                stop_reason=self.convert_to_sambanova_finish_reason(choice.finish_reason),
                tool_calls=self.convert_to_sambanova_tool_calls(choice.message.tool_calls),
            ),
            logprobs=None,
        )

        return result

    async def _stream_chat_completion(self, request: dict) -> AsyncGenerator:
        async def _to_async_generator():
            # The OpenAI client returns a synchronous stream; wrap it so it can
            # be consumed as an async generator.
            streaming = self._get_client().chat.completions.create(**request)
            for chunk in streaming:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def convert_chat_completion_request(self, request: ChatCompletionRequest) -> dict:
        compatible_request = self.convert_sampling_params(request.sampling_params)
        compatible_request["model"] = request.model
        compatible_request["messages"] = await self.convert_to_sambanova_messages(request.messages)
        compatible_request["stream"] = request.stream
        compatible_request["logprobs"] = False
        compatible_request["extra_headers"] = {
            b"User-Agent": b"llama-stack: sambanova-inference-adapter",
        }
        compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
        return compatible_request

    def convert_sampling_params(self, sampling_params: SamplingParams, legacy: bool = False) -> dict:
        params = {}

        if sampling_params:
            params["frequency_penalty"] = sampling_params.repetition_penalty

            if sampling_params.max_tokens:
                if legacy:
                    params["max_tokens"] = sampling_params.max_tokens
                else:
                    params["max_completion_tokens"] = sampling_params.max_tokens

            if isinstance(sampling_params.strategy, TopPSamplingStrategy):
                params["top_p"] = sampling_params.strategy.top_p
            if isinstance(sampling_params.strategy, TopKSamplingStrategy):
                # top_k is not a standard OpenAI parameter, so pass it via extra_body.
                params.setdefault("extra_body", {})["top_k"] = sampling_params.strategy.top_k
            if isinstance(sampling_params.strategy, GreedySamplingStrategy):
                params["temperature"] = 0.0

        return params

    async def convert_to_sambanova_messages(self, messages: List[Message]) -> List[dict]:
        conversation = []
        for message in messages:
            content = {}

            content["content"] = await self.convert_to_sambanova_content(message)

            if isinstance(message, UserMessage):
                content["role"] = "user"
            elif isinstance(message, CompletionMessage):
                content["role"] = "assistant"
                tools = []
                for tool_call in message.tool_calls:
                    tools.append(
                        {
                            "id": tool_call.call_id,
                            "function": {
                                "name": tool_call.name,
                                "arguments": json.dumps(tool_call.arguments),
                            },
                            "type": "function",
                        }
                    )
                content["tool_calls"] = tools
            elif isinstance(message, ToolResponseMessage):
                content["role"] = "tool"
                content["tool_call_id"] = message.call_id
            elif isinstance(message, SystemMessage):
                content["role"] = "system"

            conversation.append(content)

        return conversation

    async def convert_to_sambanova_content(self, message: Message) -> dict:
        async def _convert_content(content) -> dict:
            if isinstance(content, ImageContentItem):
                url = await convert_image_content_to_url(content, download=True)
                # Lowercase the data-URL prefix so the API call succeeds.
                components = url.split(";base64")
                url = f"{components[0].lower()};base64{components[1]}"
                return {
                    "type": "image_url",
                    "image_url": {"url": url},
                }
            else:
                text = content.text if isinstance(content, TextContentItem) else content
                assert isinstance(text, str)
                return {"type": "text", "text": text}

        if isinstance(message.content, list):
            # If the content is a list, each text item must be wrapped in a dict.
            content = [await _convert_content(c) for c in message.content]
        else:
            content = message.content

        return content

    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
        if tools is None:
            return tools

        compatible_tools = []

        for tool in tools:
            properties = {}
            compatible_required = []
            if tool.parameters:
                for tool_key, tool_param in tool.parameters.items():
                    properties[tool_key] = {"type": tool_param.param_type}
                    if tool_param.description:
                        properties[tool_key]["description"] = tool_param.description
                    if tool_param.default:
                        properties[tool_key]["default"] = tool_param.default
                    if tool_param.required:
                        compatible_required.append(tool_key)

            compatible_tool = {
                "type": "function",
                "function": {
                    "name": tool.tool_name,
                    "description": tool.description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": compatible_required,
                    },
                },
            }

            compatible_tools.append(compatible_tool)

        if len(compatible_tools) > 0:
            return compatible_tools
        return None

    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
        return {
            "stop": StopReason.end_of_turn,
            "length": StopReason.out_of_tokens,
            "tool_calls": StopReason.end_of_message,
        }.get(finish_reason, StopReason.end_of_turn)

    def convert_to_sambanova_tool_calls(
        self,
        tool_calls,
    ) -> List[ToolCall]:
        if not tool_calls:
            return []

        # Parse each call's JSON-encoded arguments individually rather than
        # reusing the arguments of the last call for every ToolCall.
        compatible_tool_calls = [
            ToolCall(
                call_id=call.id,
                tool_name=call.function.name,
                arguments=json.loads(call.function.arguments),
            )
            for call in tool_calls
        ]

        return compatible_tool_calls