diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 59189f8bb..95c7759b8 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         await end_trace()
 
         json_content = json.dumps(convert_pydantic_to_json_value(result))
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
         response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
 
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 1f1306f0d..1309e72a6 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -518,40 +518,44 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
     # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
     # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
     # List[...] -> List[...]
-    async def _convert_user_message_content(
+    async def _convert_message_content(
         content: InterleavedContent,
     ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
-        # Llama Stack and OpenAI spec match for str and text input
-        if isinstance(content, str):
-            return OpenAIChatCompletionContentPartTextParam(
-                type="text",
-                text=content,
-            )
-        elif isinstance(content, TextContentItem):
-            return OpenAIChatCompletionContentPartTextParam(
-                type="text",
-                text=content.text,
-            )
-        elif isinstance(content, ImageContentItem):
-            return OpenAIChatCompletionContentPartImageParam(
-                type="image_url",
-                image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
-            )
-        elif isinstance(content, List):
-            return [await _convert_user_message_content(item) for item in content]
+        async def impl():
+            # Llama Stack and OpenAI spec match for str and text input
+            if isinstance(content, str):
+                return content
+            elif isinstance(content, TextContentItem):
+                return OpenAIChatCompletionContentPartTextParam(
+                    type="text",
+                    text=content.text,
+                )
+            elif isinstance(content, ImageContentItem):
+                return OpenAIChatCompletionContentPartImageParam(
+                    type="image_url",
+                    image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
+                )
+            elif isinstance(content, list):
+                return [await _convert_message_content(item) for item in content]
+            else:
+                raise ValueError(f"Unsupported content type: {type(content)}")
+
+        ret = await impl()
+        if isinstance(ret, str) or isinstance(ret, list):
+            return ret
         else:
-            raise ValueError(f"Unsupported content type: {type(content)}")
+            return [ret]
 
     out: OpenAIChatCompletionMessage = None
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
-            content=await _convert_user_message_content(message.content),
+            content=await _convert_message_content(message.content),
         )
     elif isinstance(message, CompletionMessage):
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
-            content=message.content,
+            content=await _convert_message_content(message.content),
             tool_calls=[
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
@@ -568,12 +572,12 @@ async def convert_message_to_openai_dict_new(message: Message | Dict) -> OpenAIC
         out = OpenAIChatCompletionToolMessage(
             role="tool",
             tool_call_id=message.call_id,
-            content=message.content,
+            content=await _convert_message_content(message.content),
         )
     elif isinstance(message, SystemMessage):
         out = OpenAIChatCompletionSystemMessage(
             role="system",
-            content=message.content,
+            content=await _convert_message_content(message.content),
         )
     else:
         raise ValueError(f"Unsupported message type: {type(message)}")
diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py
index 53afcaa4a..2fd068efc 100644
--- a/tests/client-sdk/inference/test_text_inference.py
+++ b/tests/client-sdk/inference/test_text_inference.py
@@ -7,14 +7,9 @@
 import pytest
 from pydantic import BaseModel
 
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.tests.test_cases.test_case import TestCase
-
-PROVIDER_TOOL_PROMPT_FORMAT = {
-    "remote::ollama": "json",
-    "remote::together": "json",
-    "remote::fireworks": "json",
-    "remote::vllm": "json",
-}
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
 
@@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    print(f"Provider: {provider.provider_type} for model {model_id}")
     if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
 
-@pytest.fixture(scope="session")
-def provider_tool_format(inference_provider_type):
-    return (
-        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
-        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
-        else None
-    )
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+
+    assert model_id in models, f"Model {model_id} not found"
+
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+
+    return model.metadata.get("llama_model", None)
+
+
+def get_tool_prompt_format(client_with_models, model_id):
+    llama_model = get_llama_model(client_with_models, model_id)
+    if not llama_model:
+        return None
+    return get_default_tool_prompt_format(llama_model)
 
 
 @pytest.mark.parametrize(
@@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_required(
-    client_with_models,
-    text_model_id,
-    provider_tool_format,
-    test_case,
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
 
@@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
-        tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
+        tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
        stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)
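
Note (not part of the patch): a minimal sketch of how the new test-side lookup is intended to behave, using only the two imports added above. The helper name lookup_tool_prompt_format is hypothetical; the tests themselves call get_tool_prompt_format(client_with_models, text_model_id), which first maps the registered model back to a Llama model id via get_llama_model().

# Hypothetical sketch, not part of the diff. Assumes resolve_model() returns a falsy
# value for ids that are not known Llama SKUs, as implied by get_llama_model() above.
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format


def lookup_tool_prompt_format(model_id: str):
    # Non-Llama models have no default tool prompt format; the tests then pass None
    # and let the provider fall back to its own default.
    if not resolve_model(model_id):
        return None
    return get_default_tool_prompt_format(model_id)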