diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index bf55c5ad2..823c7267c 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -48,10 +48,10 @@ model_aliases = [
         "llama3.1:8b-instruct-fp16",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_model_alias_with_just_provider_model_id(
-        "llama3.1:8b",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
+    # build_model_alias_with_just_provider_model_id(
+    #     "llama3.1:8b",
+    #     CoreModelId.llama3_1_8b_instruct.value,
+    # ),
     build_model_alias(
         "llama3.1:70b-instruct-fp16",
         CoreModelId.llama3_1_70b_instruct.value,
@@ -214,6 +214,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+            response_format=response_format,
         )
         if stream:
             return self._stream_chat_completion(request)
@@ -257,6 +258,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             )
             input_dict["raw"] = True
 
+        if fmt := request.response_format:
+            if fmt.type == "json_schema":
+                input_dict["format"] = fmt.json_schema
+            elif fmt.type == "grammar":
+                raise NotImplementedError("Grammar response format is not supported")
+            else:
+                raise ValueError(f"Unknown response format type: {fmt.type}")
+
         return {
             "model": request.model,
             **input_dict,
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 99a62ac08..3d3d1ad55 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -191,6 +191,7 @@ class TestInference:
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
+            "remote::ollama",
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
@@ -253,6 +254,7 @@ class TestInference:
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
+            "remote::ollama",
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
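
Note: the sketch below is illustrative and not part of the patch. It shows how the response-format branch added to _get_params routes a json_schema format into the params dict that the adapter sends to Ollama (which accepts a JSON schema via its "format" field). The `ResponseFormat` dataclass and `apply_response_format` helper are hypothetical stand-ins, written only to make the example self-contained; the real request type comes from llama_stack.apis.inference.

    # Hypothetical stand-in types/helpers -- not the real llama_stack API.
    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class ResponseFormat:
        type: str
        json_schema: dict = field(default_factory=dict)

    def apply_response_format(input_dict: dict, fmt: Optional[ResponseFormat]) -> dict:
        # Mirrors the branch added to _get_params above: json_schema is
        # forwarded to Ollama via its "format" field; grammar is unsupported.
        if fmt:
            if fmt.type == "json_schema":
                input_dict["format"] = fmt.json_schema
            elif fmt.type == "grammar":
                raise NotImplementedError("Grammar response format is not supported")
            else:
                raise ValueError(f"Unknown response format type: {fmt.type}")
        return input_dict

    params = apply_response_format(
        {"model": "llama3.1:8b-instruct-fp16"},
        ResponseFormat(
            type="json_schema",
            json_schema={
                "type": "object",
                "properties": {"answer": {"type": "string"}},
                "required": ["answer"],
            },
        ),
    )
    assert params["format"]["required"] == ["answer"]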