From d78027f3b5961fca4fe407b47dedaf7a5d364365 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 23 Jan 2025 12:25:12 -0800
Subject: [PATCH] Move runpod provider to the correct directory

Also cleanup the test code to avoid skipping tests. Let failures be known
and public.
---
 llama_stack/providers/registry/inference.py    |  4 +-
 .../inference/runpod/__init__.py               |  2 +-
 .../inference/runpod/config.py                 |  0
 .../inference/runpod/runpod.py                 |  5 +-
 .../tests/inference/test_text_inference.py     | 71 -------------------
 .../tests/inference/test_vision_inference.py   | 27 -------
 6 files changed, 7 insertions(+), 102 deletions(-)
 rename llama_stack/providers/{adapters => remote}/inference/runpod/__init__.py (93%)
 rename llama_stack/providers/{adapters => remote}/inference/runpod/config.py (100%)
 rename llama_stack/providers/{adapters => remote}/inference/runpod/runpod.py (99%)

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index af2cb8e65..e72140ccf 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -200,8 +200,8 @@ def available_providers() -> List[ProviderSpec]:
         adapter=AdapterSpec(
             adapter_type="runpod",
             pip_packages=["openai"],
-            module="llama_stack.providers.adapters.inference.runpod",
-            config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
+            module="llama_stack.providers.remote.inference.runpod",
+            config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
         ),
     ),
     remote_provider_spec(
diff --git a/llama_stack/providers/adapters/inference/runpod/__init__.py b/llama_stack/providers/remote/inference/runpod/__init__.py
similarity index 93%
rename from llama_stack/providers/adapters/inference/runpod/__init__.py
rename to llama_stack/providers/remote/inference/runpod/__init__.py
index 67d49bc45..37432dbb4 100644
--- a/llama_stack/providers/adapters/inference/runpod/__init__.py
+++ b/llama_stack/providers/remote/inference/runpod/__init__.py
@@ -10,7 +10,7 @@ from .runpod import RunpodInferenceAdapter

 async def get_adapter_impl(config: RunpodImplConfig, _deps):
     assert isinstance(
-       config, RunpodImplConfig
+        config, RunpodImplConfig
     ), f"Unexpected config type: {type(config)}"
     impl = RunpodInferenceAdapter(config)
     await impl.initialize()
diff --git a/llama_stack/providers/adapters/inference/runpod/config.py b/llama_stack/providers/remote/inference/runpod/config.py
similarity index 100%
rename from llama_stack/providers/adapters/inference/runpod/config.py
rename to llama_stack/providers/remote/inference/runpod/config.py
diff --git a/llama_stack/providers/adapters/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
similarity index 99%
rename from llama_stack/providers/adapters/inference/runpod/runpod.py
rename to llama_stack/providers/remote/inference/runpod/runpod.py
index cb2e6b237..e5b19426f 100644
--- a/llama_stack/providers/adapters/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -12,6 +12,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
+
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

@@ -40,6 +41,8 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }
+
+
 class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(
@@ -130,4 +133,4 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 7201fdc4a..5f1a429a1 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -109,19 +109,6 @@ class TestInference:
     async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::ollama",
-            "remote::tgi",
-            "remote::together",
-            "remote::fireworks",
-            "remote::nvidia",
-            "remote::cerebras",
-            "remote::vllm",
-        ):
-            pytest.skip("Other inference providers don't support completion() yet")
-
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
@@ -155,12 +142,6 @@ class TestInference:
     async def test_completion_logprobs(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            # "remote::nvidia", -- provider doesn't provide all logprobs
-        ):
-            pytest.skip("Other inference providers don't support completion() yet")
-
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
@@ -212,21 +193,6 @@ class TestInference:
     async def test_completion_structured_output(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::ollama",
-            "remote::tgi",
-            "remote::together",
-            "remote::fireworks",
-            "remote::nvidia",
-            "remote::vllm",
-            "remote::cerebras",
-        ):
-            pytest.skip(
-                "Other inference providers don't support structured output in completions yet"
-            )
-
         class Output(BaseModel):
             name: str
             year_born: str
@@ -275,18 +241,6 @@ class TestInference:
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::ollama",
-            "remote::fireworks",
-            "remote::tgi",
-            "remote::together",
-            "remote::vllm",
-            "remote::nvidia",
-        ):
-            pytest.skip("Other inference providers don't support structured output yet")
-
         class AnswerFormat(BaseModel):
             first_name: str
             last_name: str
@@ -377,20 +331,6 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if (
-            provider.__provider_spec__.provider_type == "remote::groq"
-            and "Llama-3.2" in inference_model
-        ):
-            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
-            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
-
-        if provider.__provider_spec__.provider_type == "remote::sambanova" and (
-            "-1B-" in inference_model or "-3B-" in inference_model
-        ):
-            # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/ 3B
-            pytest.skip("Sambanova's tool calling for lightweight models don't work")
-
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -430,17 +370,6 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if (
-            provider.__provider_spec__.provider_type == "remote::groq"
-            and "Llama-3.2" in inference_model
-        ):
-            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
-            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
-        if provider.__provider_spec__.provider_type == "remote::sambanova":
-            # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
-            pytest.skip("Sambanova's tool calling for streaming doesn't work")
-
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py
index fba7cefde..a06c4a7d5 100644
--- a/llama_stack/providers/tests/inference/test_vision_inference.py
+++ b/llama_stack/providers/tests/inference/test_vision_inference.py
@@ -51,20 +51,6 @@ class TestVisionModelInference:
         self, inference_model, inference_stack, image, expected_strings
     ):
         inference_impl, _ = inference_stack
-
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::together",
-            "remote::fireworks",
-            "remote::ollama",
-            "remote::vllm",
-            "remote::sambanova",
-        ):
-            pytest.skip(
-                "Other inference providers don't support vision chat completion() yet"
-            )
-
         response = await inference_impl.chat_completion(
             model_id=inference_model,
             messages=[
@@ -92,19 +78,6 @@ class TestVisionModelInference:
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::together",
-            "remote::fireworks",
-            "remote::ollama",
-            "remote::vllm",
-            "remote::sambanova",
-        ):
-            pytest.skip(
-                "Other inference providers don't support vision chat completion() yet"
-            )
-
         images = [
             ImageContentItem(
                 image=dict(