diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py
index 59189f8bb..95c7759b8 100644
--- a/llama_stack/distribution/library_client.py
+++ b/llama_stack/distribution/library_client.py
@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             await end_trace()
 
         json_content = json.dumps(convert_pydantic_to_json_value(result))
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
         response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
 
diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py
index 53afcaa4a..2fd068efc 100644
--- a/tests/client-sdk/inference/test_text_inference.py
+++ b/tests/client-sdk/inference/test_text_inference.py
@@ -7,14 +7,9 @@
 import pytest
 from pydantic import BaseModel
 
+from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.tests.test_cases.test_case import TestCase
-
-PROVIDER_TOOL_PROMPT_FORMAT = {
-    "remote::ollama": "json",
-    "remote::together": "json",
-    "remote::fireworks": "json",
-    "remote::vllm": "json",
-}
+from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format
 
 PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
 
@@ -24,18 +19,32 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
     provider_id = models[model_id].provider_id
     providers = {p.provider_id: p for p in client_with_models.providers.list()}
     provider = providers[provider_id]
-    print(f"Provider: {provider.provider_type} for model {model_id}")
     if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
 
-@pytest.fixture(scope="session")
-def provider_tool_format(inference_provider_type):
-    return (
-        PROVIDER_TOOL_PROMPT_FORMAT[inference_provider_type]
-        if inference_provider_type in PROVIDER_TOOL_PROMPT_FORMAT
-        else None
-    )
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+
+    assert model_id in models, f"Model {model_id} not found"
+
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+
+    return model.metadata.get("llama_model", None)
+
+
+def get_tool_prompt_format(client_with_models, model_id):
+    llama_model = get_llama_model(client_with_models, model_id)
+    if not llama_model:
+        return None
+    return get_default_tool_prompt_format(llama_model)
 
 
 @pytest.mark.parametrize(
@@ -237,12 +246,8 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -280,12 +285,8 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, provider_tool_format, test_case
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
-
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -308,14 +309,8 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_required(
-    client_with_models,
-    text_model_id,
-    provider_tool_format,
-    test_case,
-):
-    # TODO: more dynamic lookup on tool_prompt_format for model family
-    tool_prompt_format = "json" if "3.1" in text_model_id else "python_list"
+def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
 
@@ -341,14 +336,15 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, provider_tool_format, test_case):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
+    tool_prompt_format = get_tool_prompt_format(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
         model_id=text_model_id,
         messages=tc["messages"],
         tools=tc["tools"],
-        tool_config={"tool_choice": "none", "tool_prompt_format": provider_tool_format},
+        tool_config={"tool_choice": "none", "tool_prompt_format": tool_prompt_format},
        stream=True,
     )
     tool_invocation_content = extract_tool_invocation_content(response)