diff --git a/llama_stack/providers/adapters/inference/runpod/runpod.py b/llama_stack/providers/adapters/inference/runpod/runpod.py index 6c8bd63da..a6255dfe3 100644 --- a/llama_stack/providers/adapters/inference/runpod/runpod.py +++ b/llama_stack/providers/adapters/inference/runpod/runpod.py @@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import RunpodImplConfig -VLLM_SUPPORTED_MODELS = { +RUNPOD_SUPPORTED_MODELS = { "Llama3.1-8B": "meta-llama/Llama-3.1-8B", "Llama3.1-70B": "meta-llama/Llama-3.1-70B", "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B", @@ -64,7 +64,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate): self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token) async def register_model(self, model: ModelDef) -> None: - raise ValueError("Model registration is not supported for vLLM models") + raise ValueError("Model registration is not supported for Runpod models") async def shutdown(self) -> None: pass @@ -140,7 +140,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate): def _get_params(self, request: ChatCompletionRequest) -> dict: return { - "model": VLLM_SUPPORTED_MODELS[request.model], + "model": RUNPOD_SUPPORTED_MODELS[request.model], "prompt": chat_completion_request_to_prompt(request, self.formatter), "stream": request.stream, **get_sampling_options(request.sampling_params), diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 88265f1b4..b78425e7f 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -149,4 +149,13 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.impls.vllm", config_class="llama_stack.providers.impls.vllm.VLLMConfig", ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="runpod", + pip_packages=["openai"], + module="llama_stack.providers.adapters.inference.runpod", + config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig", + ), + ), ]