mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00

Support Tooling

commit c0d9b81253 (parent efd842d605)
4 changed files with 151 additions and 20 deletions

File 1 of 4: the SambaNova inference adapter (remote::sambanova).

@@ -23,11 +23,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
-from llama_stack.providers.utils.inference.prompt_adapter import convert_message_to_dict
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+)
 
 from .config import SambaNovaImplConfig
 
@@ -69,6 +70,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             self,
             model_aliases=MODEL_ALIASES,
         )
 
         self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
 
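For context (an assumption based on the surviving imports, not something shown in this hunk): ChatFormat and Tokenizer come from the llama_models reference implementation, and the formatter is what the retained openai_compat helpers use to decode streamed chunks back into llama-stack message types. A minimal sketch of the names this line relies on:

    # Hedged sketch: assumed imports backing the new self.formatter line.
    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    formatter = ChatFormat(Tokenizer.get_instance())
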
@@ -118,24 +120,38 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         )
         request_sambanova = await self.convert_chat_completion_request(request)
 
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request_sambanova, client)
+            return self._stream_chat_completion(request_sambanova)
         else:
-            return await self._nonstream_chat_completion(request_sambanova, client)
+            return await self._nonstream_chat_completion(request_sambanova)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        r = client.chat.completions.create(**request)
-        return process_chat_completion_response(r, self.formatter)
+        response = self._get_client().chat.completions.create(**request)
+        choice = response.choices[0]
+
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=self.convert_to_sambanova_finish_reason(
+                    choice.finish_reason
+                ),
+                tool_calls=self.convert_to_sambanova_tool_calls(
+                    choice.message.tool_calls
+                ),
+            ),
+            logprobs=None,
+        )
+
+        return result
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         async def _to_async_generator():
-            s = client.chat.completions.create(**request)
-            for chunk in s:
+            streaming = self._get_client().chat.completions.create(**request)
+            for chunk in streaming:
                 yield chunk
 
         stream = _to_async_generator()
 
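Both paths now call self._get_client(), which is defined outside the lines shown here. A plausible sketch of that helper, inferred from the deleted inline construction (an assumption, not code from this commit):

    # Hypothetical helper mirroring the removed "client = OpenAI(...)" line.
    from openai import OpenAI

    def _get_client(self) -> OpenAI:
        # SambaNova serves an OpenAI-compatible API, so the stock SDK client
        # works once pointed at the configured base URL and API key.
        return OpenAI(base_url=self.config.url, api_key=self.config.api_key)
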
@@ -156,7 +172,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> dict:
         compatible_request = self.convert_sampling_params(request.sampling_params)
         compatible_request["model"] = request.model
-        compatible_request["messages"] = await self.convert_to_sambanova_message(
+        compatible_request["messages"] = await self.convert_to_sambanova_messages(
             request.messages
         )
         compatible_request["stream"] = request.stream
 
@@ -164,6 +180,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         compatible_request["extra_headers"] = {
             b"User-Agent": b"llama-stack: sambanova-inference-adapter",
         }
+        compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
         return compatible_request
 
     def convert_sampling_params(
 
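Assembled, convert_chat_completion_request now returns a plain dict that is splatted directly into client.chat.completions.create(**request). An illustrative result for a tool-call request (every value below is made up):

    # Illustrative only: approximate shape of the converted request.
    request_sambanova = {
        "model": "Meta-Llama-3.1-8B-Instruct",  # hypothetical model ID
        "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
        "stream": False,
        "extra_headers": {b"User-Agent": b"llama-stack: sambanova-inference-adapter"},
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",  # hypothetical tool
                    "description": "Look up current weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
    }
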
@@ -189,12 +206,15 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
 
         return params
 
-    async def convert_to_sambanova_message(self, messages: List[Message]) -> List[dict]:
+    async def convert_to_sambanova_messages(
+        self, messages: List[Message]
+    ) -> List[dict]:
         conversation = []
         for message in messages:
-            content = await convert_message_to_dict(message)
+            content = {}
+
+            content["content"] = await self.convert_to_sambanova_content(message)
 
             # Need to override role
             if isinstance(message, UserMessage):
                 content["role"] = "user"
             elif isinstance(message, CompletionMessage):
 
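For orientation, an illustrative output of the new two-step conversion for a mixed text-and-image conversation (values are made up; the "assistant" role for CompletionMessage is an assumption, since the hunk is cut off above):

    # Illustrative only: convert_to_sambanova_messages output shape.
    [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            ],
        },
        {"role": "assistant", "content": "It's a cat."},
    ]
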
@@ -221,3 +241,92 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             conversation.append(content)
 
         return conversation
+
+    async def convert_to_sambanova_content(self, message: Message) -> dict:
+        async def _convert_content(content) -> dict:
+            if isinstance(content, ImageMedia):
+                download = False
+                if isinstance(content, ImageMedia) and isinstance(content.image, URL):
+                    download = content.image.uri.startswith("https://")
+                return {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": await convert_image_media_to_url(
+                            content, download=download
+                        ),
+                    },
+                }
+            else:
+                assert isinstance(content, str)
+                return {"type": "text", "text": content}
+
+        if isinstance(message.content, list):
+            # If it is a list, the text content should be wrapped in dict
+            content = [await _convert_content(c) for c in message.content]
+        else:
+            content = message.content
+
+        return content
+
+    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
+        if tools is None:
+            return tools
+
+        compatible_tools = []
+
+        for tool in tools:
+            properties = {}
+            compatible_required = []
+            if tool.parameters:
+                for tool_key, tool_param in tool.parameters.items():
+                    properties[tool_key] = {"type": tool_param.param_type}
+                    if tool_param.description:
+                        properties[tool_key]["description"] = tool_param.description
+                    if tool_param.default:
+                        properties[tool_key]["default"] = tool_param.default
+                    if tool_param.required:
+                        compatible_required.append(tool_key)
+
+            compatible_tool = {
+                "type": "function",
+                "function": {
+                    "name": tool.tool_name,
+                    "description": tool.description,
+                    "parameters": {
+                        "type": "object",
+                        "properties": properties,
+                        "required": compatible_required,
+                    },
+                },
+            }
+
+            compatible_tools.append(compatible_tool)
+
+        if len(compatible_tools) > 0:
+            return compatible_tools
+        return None
+
+    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
+        return {
+            "stop": StopReason.end_of_turn,
+            "length": StopReason.out_of_tokens,
+            "tool_calls": StopReason.end_of_message,
+        }.get(finish_reason, StopReason.end_of_turn)
+
+    def convert_to_sambanova_tool_calls(
+        self,
+        tool_calls,
+    ) -> List[ToolCall]:
+        if not tool_calls:
+            return []
+
+        compatible_tool_calls = [
+            ToolCall(
+                call_id=call.id,
+                tool_name=call.function.name,
+                arguments=call.function.arguments,
+            )
+            for call in tool_calls
+        ]
+
+        return compatible_tool_calls
 
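Taken together, these helpers give the adapter a full tool round trip: outbound, llama-stack ToolDefinition objects become OpenAI-style function tools; inbound, finish reasons and tool calls from the response are folded back into llama-stack types. A hedged usage sketch (adapter and response are illustrative stand-ins):

    # Outbound: llama-stack tools -> OpenAI-style dicts (None if no tools).
    tools = adapter.convert_to_sambanova_tool(request.tools)

    # Inbound: OpenAI-style choice -> llama-stack types. Note that
    # call.function.arguments is forwarded as-is (the OpenAI SDK returns it
    # as a JSON string), and a missing tool_calls list maps to [].
    choice = response.choices[0]
    stop_reason = adapter.convert_to_sambanova_finish_reason(choice.finish_reason)
    tool_calls = adapter.convert_to_sambanova_tool_calls(choice.message.tool_calls)
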

File 2 of 4: the shared inference test fixtures.

@@ -20,6 +20,7 @@ from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
+from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
 from llama_stack.providers.remote.inference.tgi import TGIImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
 
@@ -173,6 +174,24 @@ def inference_tgi() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_sambanova() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="sambanova",
+                provider_type="remote::sambanova",
+                config=SambaNovaImplConfig(
+                    api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+                ).model_dump(),
+            )
+        ],
+        provider_data=dict(
+            sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
+        ),
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
 
@@ -208,6 +227,7 @@ INFERENCE_FIXTURES = [
     "bedrock",
     "nvidia",
     "tgi",
+    "sambanova",
 ]
 
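One usage note: the fixture reads its credentials through get_env_or_fail at setup time, so any test run that selects "sambanova" from INFERENCE_FIXTURES must export the key first (a placeholder sketch; the variable name comes from the diff, the value is yours):

    # Required before running the inference tests against SambaNova.
    import os
    os.environ["SAMBANOVA_API_KEY"] = "<your-sambanova-api-key>"
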

File 3 of 4: the prompt adapter unit tests (PrepareMessagesTests).

@@ -24,7 +24,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 UserMessage(content=content),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
 
@@ -41,7 +41,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.brave_search),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2)
         self.assertEqual(messages[-1].content, content)
         self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
 
@@ -69,7 +69,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ],
             tool_prompt_format=ToolPromptFormat.json,
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
         self.assertTrue("Environment: ipython" in messages[0].content)
 
@@ -99,7 +99,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
             ),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 3)
 
         self.assertTrue("Environment: ipython" in messages[0].content)
 
@@ -121,7 +121,7 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
                 ToolDefinition(tool_name=BuiltinTool.code_interpreter),
             ],
         )
-        messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request, MODEL)
         self.assertEqual(len(messages), 2, messages)
         self.assertTrue(messages[0].content.endswith(system_prompt))
 
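All five hunks make the same mechanical change: chat_completion_request_to_messages now takes the model identifier as a second argument, presumably to resolve model-specific prompt formatting (the function body is not shown here). A minimal sketch of the new call shape, with an illustrative MODEL value:

    # Hedged sketch of the two-argument form used throughout these tests.
    MODEL = "Llama3.1-8B-Instruct"  # illustrative; the test module defines its own

    request = ChatCompletionRequest(
        model=MODEL,
        messages=[UserMessage(content="hello")],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
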

File 4 of 4: the vision model inference tests (TestVisionModelInference).

@@ -49,6 +49,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"
 
@@ -83,6 +84,7 @@ class TestVisionModelInference:
             "remote::fireworks",
             "remote::ollama",
             "remote::vllm",
+            "remote::sambanova",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"