Support Tooling

Edward Ma 2024-12-02 13:38:54 -08:00
parent efd842d605
commit c0d9b81253
4 changed files with 151 additions and 20 deletions

View file

@@ -23,11 +23,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
-from llama_stack.providers.utils.inference.prompt_adapter import convert_message_to_dict
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+)
 
 from .config import SambaNovaImplConfig
@@ -69,6 +70,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             self,
             model_aliases=MODEL_ALIASES,
         )
+        self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -118,24 +120,38 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         )
         request_sambanova = await self.convert_chat_completion_request(request)
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request_sambanova, client)
+            return self._stream_chat_completion(request_sambanova)
         else:
-            return await self._nonstream_chat_completion(request_sambanova, client)
+            return await self._nonstream_chat_completion(request_sambanova)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        r = client.chat.completions.create(**request)
-        return process_chat_completion_response(r, self.formatter)
+        response = self._get_client().chat.completions.create(**request)
+
+        choice = response.choices[0]
+
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=self.convert_to_sambanova_finish_reason(
+                    choice.finish_reason
+                ),
+                tool_calls=self.convert_to_sambanova_tool_calls(
+                    choice.message.tool_calls
+                ),
+            ),
+            logprobs=None,
+        )
+        return result
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         async def _to_async_generator():
-            s = client.chat.completions.create(**request)
-            for chunk in s:
+            streaming = self._get_client().chat.completions.create(**request)
+
+            for chunk in streaming:
                 yield chunk
 
         stream = _to_async_generator()
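
Note: both branches now route through self._get_client(), whose definition this hunk does not show. A minimal sketch of what that helper plausibly looks like, assuming it simply mirrors the inline construction deleted above:

    from openai import OpenAI

    def _get_client(self) -> OpenAI:
        # Hypothetical sketch -- the helper is not shown in this diff. It is
        # assumed to build an OpenAI-compatible client against the configured
        # SambaNova endpoint, as the removed inline `client = OpenAI(...)`
        # line did at each call site.
        return OpenAI(base_url=self.config.url, api_key=self.config.api_key)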
@@ -156,7 +172,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> dict:
         compatible_request = self.convert_sampling_params(request.sampling_params)
         compatible_request["model"] = request.model
-        compatible_request["messages"] = await self.convert_to_sambanova_message(
+        compatible_request["messages"] = await self.convert_to_sambanova_messages(
             request.messages
         )
         compatible_request["stream"] = request.stream
@@ -164,6 +180,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         compatible_request["extra_headers"] = {
             b"User-Agent": b"llama-stack: sambanova-inference-adapter",
         }
+        compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
         return compatible_request
 
     def convert_sampling_params(
@@ -189,12 +206,15 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         return params
 
-    async def convert_to_sambanova_message(self, messages: List[Message]) -> List[dict]:
+    async def convert_to_sambanova_messages(
+        self, messages: List[Message]
+    ) -> List[dict]:
         conversation = []
         for message in messages:
-            content = await convert_message_to_dict(message)
+            content = {}
+            content["content"] = await self.convert_to_sambanova_content(message)
 
             # Need to override role
             if isinstance(message, UserMessage):
                 content["role"] = "user"
             elif isinstance(message, CompletionMessage):
@@ -221,3 +241,92 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             conversation.append(content)
 
         return conversation
+
+    async def convert_to_sambanova_content(self, message: Message) -> dict:
+        async def _convert_content(content) -> dict:
+            if isinstance(content, ImageMedia):
+                download = False
+                if isinstance(content, ImageMedia) and isinstance(content.image, URL):
+                    download = content.image.uri.startswith("https://")
+                return {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": await convert_image_media_to_url(
+                            content, download=download
+                        ),
+                    },
+                }
+            else:
+                assert isinstance(content, str)
+                return {"type": "text", "text": content}
+
+        if isinstance(message.content, list):
+            # If it is a list, the text content should be wrapped in dict
+            content = [await _convert_content(c) for c in message.content]
+        else:
+            content = message.content
+
+        return content
+
+    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
+        if tools is None:
+            return tools
+
+        compatiable_tools = []
+        for tool in tools:
+            properties = {}
+            compatiable_required = []
+            if tool.parameters:
+                for tool_key, tool_param in tool.parameters.items():
+                    properties[tool_key] = {"type": tool_param.param_type}
+                    if tool_param.description:
+                        properties[tool_key]["description"] = tool_param.description
+                    if tool_param.default:
+                        properties[tool_key]["default"] = tool_param.default
+                    if tool_param.required:
+                        compatiable_required.append(tool_key)
+
+            compatiable_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool.tool_name,
+                    "description": tool.description,
+                    "parameters": {
+                        "type": "object",
+                        "properties": properties,
+                        "required": compatiable_required,
+                    },
+                },
+            }
+
+            compatiable_tools.append(compatiable_tool)
+
+        if len(compatiable_tools) > 0:
+            return compatiable_tools
+        return None
+
+    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
+        return {
+            "stop": StopReason.end_of_turn,
+            "length": StopReason.out_of_tokens,
+            "tool_calls": StopReason.end_of_message,
+        }.get(finish_reason, StopReason.end_of_turn)
+
+    def convert_to_sambanova_tool_calls(
+        self,
+        tool_calls,
+    ) -> List[ToolCall]:
+        if not tool_calls:
+            return []
+
+        compitable_tool_calls = [
+            ToolCall(
+                call_id=call.id,
+                tool_name=call.function.name,
+                arguments=call.function.arguments,
+            )
+            for call in tool_calls
+        ]
+
+        return compitable_tool_calls
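
For illustration, convert_to_sambanova_tool maps one ToolDefinition to an OpenAI-style function schema. The tool below is a made-up example (not part of this commit), assuming ToolParamDefinition comes from the same datatypes module as ToolDefinition:

    # Hypothetical input: one tool with a single required string parameter.
    tool = ToolDefinition(
        tool_name="get_weather",
        description="Fetch the current weather for a city",
        parameters={
            "city": ToolParamDefinition(
                param_type="string",
                description="City name",
                required=True,
            ),
        },
    )

    # convert_to_sambanova_tool([tool]) would then return:
    # [{"type": "function",
    #   "function": {"name": "get_weather",
    #                "description": "Fetch the current weather for a city",
    #                "parameters": {"type": "object",
    #                               "properties": {"city": {"type": "string",
    #                                                       "description": "City name"}},
    #                               "required": ["city"]}}}]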

View file

@@ -20,6 +20,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@@ -173,6 +174,24 @@ def inference_tgi() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_sambanova() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sambanova",
+                provider_type="remote::sambanova",
+                config=SambaNovaImplConfig(
+                    api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+                ).model_dump(),
+            )
+        ],
+        provider_data=dict(
+            sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+        ),
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
@@ -208,6 +227,7 @@ INFERENCE_FIXTURES = [
     "bedrock",
     "nvidia",
     "tgi",
+    "sambanova",
 ]
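
The new fixture leans on get_env_or_fail, which is defined elsewhere in the repo and not shown in this diff. Its assumed behavior, sketched here, is to read the variable or raise so a misconfigured run fails fast instead of sending empty credentials:

    import os

    def get_env_or_fail(key: str) -> str:
        # Assumed behavior only -- the real helper lives outside this diff.
        value = os.environ.get(key)
        if not value:
            raise ValueError(f"Environment variable {key} must be set")
        return value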

View file

@@ -24,7 +24,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -41,7 +41,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -69,7 +69,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -99,7 +99,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -121,7 +121,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
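
Each call site now passes the model identifier as well, presumably so message construction can apply the right prompt format for the model family. The diff does not show how MODEL is defined; a sketch with an illustrative value:

    # Illustrative only -- MODEL is a module-level constant in the test file.
    MODEL = "Llama3.1-8B-Instruct"

    messages = chat_completion_request_to_messages(request, MODEL)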

View file

@@ -49,6 +49,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"
@@ -83,6 +84,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"