mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00

commit c0d9b81253 (parent efd842d605): Support Tooling

4 changed files with 151 additions and 20 deletions
SambaNova inference adapter:

@@ -23,11 +23,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    process_chat_completion_response,
     process_chat_completion_stream_response,
 )
 
-from llama_stack.providers.utils.inference.prompt_adapter import convert_message_to_dict
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+)
 
 from .config import SambaNovaImplConfig
 
@@ -69,6 +70,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             self,
             model_aliases=MODEL_ALIASES,
         )
+
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
@@ -118,24 +120,38 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         )
         request_sambanova = await self.convert_chat_completion_request(request)
 
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request_sambanova, client)
+            return self._stream_chat_completion(request_sambanova)
         else:
-            return await self._nonstream_chat_completion(request_sambanova, client)
+            return await self._nonstream_chat_completion(request_sambanova)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        r = client.chat.completions.create(**request)
-        return process_chat_completion_response(r, self.formatter)
+        response = self._get_client().chat.completions.create(**request)
+        choice = response.choices[0]
+
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=self.convert_to_sambanova_finish_reason(
+                    choice.finish_reason
+                ),
+                tool_calls=self.convert_to_sambanova_tool_calls(
+                    choice.message.tool_calls
+                ),
+            ),
+            logprobs=None,
+        )
+
+        return result
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         async def _to_async_generator():
-            s = client.chat.completions.create(**request)
-            for chunk in s:
+            streaming = self._get_client().chat.completions.create(**request)
+            for chunk in streaming:
                 yield chunk
 
         stream = _to_async_generator()
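Note: the new code calls self._get_client(), whose definition is not shown in
this diff. A minimal sketch of what such a helper presumably looks like, based
on the removed inline construction (hypothetical, for illustration only):

    from openai import OpenAI

    class SambaNovaInferenceAdapter:  # abbreviated for illustration
        def _get_client(self) -> OpenAI:
            # Mirrors the removed inline call:
            # OpenAI(base_url=self.config.url, api_key=self.config.api_key)
            return OpenAI(base_url=self.config.url, api_key=self.config.api_key)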
@@ -156,7 +172,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> dict:
         compatible_request = self.convert_sampling_params(request.sampling_params)
         compatible_request["model"] = request.model
-        compatible_request["messages"] = await self.convert_to_sambanova_message(
+        compatible_request["messages"] = await self.convert_to_sambanova_messages(
             request.messages
         )
         compatible_request["stream"] = request.stream
@@ -164,6 +180,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         compatible_request["extra_headers"] = {
             b"User-Agent": b"llama-stack: sambanova-inference-adapter",
         }
+        compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
         return compatible_request
 
     def convert_sampling_params(
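Note: the converted request is a plain dict splatted into
chat.completions.create(**request). A hypothetical example of its final shape
(the model id and message are illustrative; the sampling fields come from
convert_sampling_params, which this diff leaves unchanged):

    compatible_request = {
        "model": "Meta-Llama-3.1-8B-Instruct",  # hypothetical model id
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
        "extra_headers": {
            b"User-Agent": b"llama-stack: sambanova-inference-adapter",
        },
        "tools": None,  # or an OpenAI-style tool list, built below
    }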
@@ -189,12 +206,15 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
 
         return params
 
-    async def convert_to_sambanova_message(self, messages: List[Message]) -> List[dict]:
+    async def convert_to_sambanova_messages(
+        self, messages: List[Message]
+    ) -> List[dict]:
         conversation = []
         for message in messages:
-            content = await convert_message_to_dict(message)
-
-            # Need to override role
+            content = {}
+
+            content["content"] = await self.convert_to_sambanova_content(message)
+
             if isinstance(message, UserMessage):
                 content["role"] = "user"
             elif isinstance(message, CompletionMessage):
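Note: a hedged sketch of what the message conversion produces (only the
"user" branch is visible in this hunk; other roles follow the elided
branches):

    # Plain string content passes through convert_to_sambanova_content as-is:
    #   UserMessage(content="Hi")   -> {"content": "Hi", "role": "user"}
    # List content is wrapped into typed parts (see the next hunk):
    #   UserMessage(content=["Hi"]) ->
    #   {"content": [{"type": "text", "text": "Hi"}], "role": "user"}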
@@ -221,3 +241,92 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             conversation.append(content)
 
         return conversation
+
+    async def convert_to_sambanova_content(self, message: Message) -> dict:
+        async def _convert_content(content) -> dict:
+            if isinstance(content, ImageMedia):
+                download = False
+                if isinstance(content, ImageMedia) and isinstance(content.image, URL):
+                    download = content.image.uri.startswith("https://")
+                return {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": await convert_image_media_to_url(
+                            content, download=download
+                        ),
+                    },
+                }
+            else:
+                assert isinstance(content, str)
+                return {"type": "text", "text": content}
+
+        if isinstance(message.content, list):
+            # If it is a list, the text content should be wrapped in dict
+            content = [await _convert_content(c) for c in message.content]
+        else:
+            content = message.content
+
+        return content
+
+    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
+        if tools is None:
+            return tools
+
+        compatible_tools = []
+
+        for tool in tools:
+            properties = {}
+            compatible_required = []
+            if tool.parameters:
+                for tool_key, tool_param in tool.parameters.items():
+                    properties[tool_key] = {"type": tool_param.param_type}
+                    if tool_param.description:
+                        properties[tool_key]["description"] = tool_param.description
+                    if tool_param.default:
+                        properties[tool_key]["default"] = tool_param.default
+                    if tool_param.required:
+                        compatible_required.append(tool_key)
+
+            compatible_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool.tool_name,
+                    "description": tool.description,
+                    "parameters": {
+                        "type": "object",
+                        "properties": properties,
+                        "required": compatible_required,
+                    },
+                },
+            }
+
+            compatible_tools.append(compatible_tool)
+
+        if len(compatible_tools) > 0:
+            return compatible_tools
+        return None
+
+    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
+        return {
+            "stop": StopReason.end_of_turn,
+            "length": StopReason.out_of_tokens,
+            "tool_calls": StopReason.end_of_message,
+        }.get(finish_reason, StopReason.end_of_turn)
+
+    def convert_to_sambanova_tool_calls(
+        self,
+        tool_calls,
+    ) -> List[ToolCall]:
+        if not tool_calls:
+            return []
+
+        compatible_tool_calls = [
+            ToolCall(
+                call_id=call.id,
+                tool_name=call.function.name,
+                arguments=call.function.arguments,
+            )
+            for call in tool_calls
+        ]
+
+        return compatible_tool_calls
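Note: convert_to_sambanova_tool maps llama-stack ToolDefinitions onto
OpenAI-style function-calling schemas. For a hypothetical "get_weather" tool
with one required string parameter, the expected output would be:

    expected = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "City name"},
                },
                "required": ["city"],
            },
        },
    }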
Inference provider test fixtures:

@@ -20,6 +20,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@@ -173,6 +174,24 @@ def inference_tgi() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_sambanova() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sambanova",
+                provider_type="remote::sambanova",
+                config=SambaNovaImplConfig(
+                    api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+                ).model_dump(),
+            )
+        ],
+        provider_data=dict(
+            sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+        ),
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
 
@@ -208,6 +227,7 @@ INFERENCE_FIXTURES = [
     "bedrock",
     "nvidia",
     "tgi",
+    "sambanova",
 ]
 
 
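Note: the fixture reads SAMBANOVA_API_KEY twice, once for the provider config
and once for per-request provider data. get_env_or_fail is not defined in this
diff; a minimal sketch of its presumable behavior (hypothetical):

    import os

    def get_env_or_fail(key: str) -> str:
        # Return the variable's value, or fail fast so a misconfigured test
        # run errors out instead of sending empty credentials.
        value = os.environ.get(key)
        if not value:
            raise ValueError(f"Expected environment variable {key} to be set")
        return value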
Prompt adapter tests (PrepareMessagesTests):

@@ -24,7 +24,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -41,7 +41,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
@@ -69,7 +69,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
 
@@ -99,7 +99,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
 
         self.assertTrue("Environment: ipython" in messages[0].content)
@@ -121,7 +121,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
 
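Note: these tests now pass the model identifier explicitly as a second
argument. MODEL is defined outside the hunks shown; presumably it is a
module-level constant naming a Llama model, e.g. (hypothetical value):

    MODEL = "Llama3.1-8B-Instruct"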
Vision model inference tests (TestVisionModelInference):

@@ -49,6 +49,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"
@@ -83,6 +84,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"