mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix(test): update client-sdk tests to handle tool format parametrization better (#1287)
# What does this PR do? Tool format depends on the model. @ehhuang introduced a `get_default_tool_prompt_format` function for this purpose. We should use that instead of hacky model ID matching we had before. Secondly, non llama models don't have this concept so testing with those models should work as is. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ```bash for distro in fireworks ollama; do LLAMA_STACK_CONFIG=$distro \ pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model=meta-llama/Llama-3.2-3B-Instruct \ --vision-inference-model="" done LLAMA_STACK_CONFIG=dev \ pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model=openai/gpt-4o \ --vision-inference-model="" ``` [//]: # (## Documentation)
This commit is contained in:
parent
30ef1c3680
commit
23b65b6cee
3 changed files with 64 additions and 63 deletions
|
@ -518,40 +518,44 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
|
|||
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
|
||||
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
|
||||
# List[...] -> List[...]
|
||||
async def _convert_user_message_content(
|
||||
async def _convert_message_content(
|
||||
content: InterleavedContent,
|
||||
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
|
||||
# Llama Stack and OpenAI spec match for str and text input
|
||||
if isinstance(content, str):
|
||||
return OpenAIChatCompletionContentPartTextParam(
|
||||
type="text",
|
||||
text=content,
|
||||
)
|
||||
elif isinstance(content, TextContentItem):
|
||||
return OpenAIChatCompletionContentPartTextParam(
|
||||
type="text",
|
||||
text=content.text,
|
||||
)
|
||||
elif isinstance(content, ImageContentItem):
|
||||
return OpenAIChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
|
||||
)
|
||||
elif isinstance(content, List):
|
||||
return [await _convert_user_message_content(item) for item in content]
|
||||
async def impl():
|
||||
# Llama Stack and OpenAI spec match for str and text input
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, TextContentItem):
|
||||
return OpenAIChatCompletionContentPartTextParam(
|
||||
type="text",
|
||||
text=content.text,
|
||||
)
|
||||
elif isinstance(content, ImageContentItem):
|
||||
return OpenAIChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
return [await _convert_message_content(item) for item in content]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
|
||||
ret = await impl()
|
||||
if isinstance(ret, str) or isinstance(ret, list):
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||
return [ret]
|
||||
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
content=await _convert_user_message_content(message.content),
|
||||
content=await _convert_message_content(message.content),
|
||||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=message.content,
|
||||
content=await _convert_message_content(message.content),
|
||||
tool_calls=[
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
id=tool.call_id,
|
||||
|
@ -568,12 +572,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
|
|||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=message.content,
|
||||
content=await _convert_message_content(message.content),
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=message.content,
|
||||
content=await _convert_message_content(message.content),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue