Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-30 15:49:40 +00:00
add RunPod provider
Signed-off-by: pandyamarut <pandyamarut@gmail.com>
parent 22a1506d65
commit 30a753d80a

2 changed files with 12 additions and 3 deletions
@@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import RunpodImplConfig
 
-VLLM_SUPPORTED_MODELS = {
+RUNPOD_SUPPORTED_MODELS = {
     "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
     "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
     "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
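For orientation, this table maps Llama Stack model aliases (keys) to the Hugging Face model ids (values) the RunPod endpoint serves. A minimal sketch of the lookup, with a hypothetical guard the adapter itself does not add (it indexes the dict directly):

# Sketch only: resolve a Llama Stack alias to the model id sent to RunPod.
RUNPOD_SUPPORTED_MODELS = {
    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
    "Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
}

def resolve_model(alias: str) -> str:
    # Hypothetical guard: surface an unsupported alias early instead of a KeyError.
    if alias not in RUNPOD_SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model alias: {alias}")
    return RUNPOD_SUPPORTED_MODELS[alias]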
@@ -64,7 +64,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
 
     async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Model registration is not supported for vLLM models")
+        raise ValueError("Model registration is not supported for Runpod models")
 
     async def shutdown(self) -> None:
         pass
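Because the adapter is a thin wrapper around the stock openai client pointed at a RunPod URL, the endpoint can be exercised the same way outside Llama Stack. A minimal sketch, assuming a hypothetical endpoint URL and token; in the adapter these come from RunpodImplConfig via self.config.url and self.config.api_token:

from openai import OpenAI

# Hypothetical values for illustration; the adapter reads these from its config.
client = OpenAI(
    base_url="https://api.runpod.ai/v2/<endpoint-id>/openai/v1",  # assumed URL shape
    api_key="YOUR_RUNPOD_API_TOKEN",
)

# Plain-text completion against the OpenAI-compatible endpoint.
response = client.completions.create(
    model="meta-llama/Llama-3.1-8B",
    prompt="Hello, world",
    stream=False,
)
print(response.choices[0].text)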
@@ -140,7 +140,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
-            "model": VLLM_SUPPORTED_MODELS[request.model],
+            "model": RUNPOD_SUPPORTED_MODELS[request.model],
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
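The dict _get_params returns is shaped so it can be splatted straight into the OpenAI completions call. Roughly what it holds for a non-streaming request; the sampling keys below are assumptions about what get_sampling_options flattens in:

# Approximate shape of the _get_params result (values illustrative):
params = {
    "model": "meta-llama/Llama-3.1-8B",  # via RUNPOD_SUPPORTED_MODELS
    "prompt": "<formatted chat prompt from the formatter>",
    "stream": False,
    "temperature": 0.7,  # assumed sampling option
    "top_p": 0.9,        # assumed sampling option
}
# Presumed call site elsewhere in the adapter (outside this diff):
# self.client.completions.create(**params)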
@@ -149,4 +149,13 @@ def available_providers() -> List[ProviderSpec]:
         module="llama_stack.providers.impls.vllm",
         config_class="llama_stack.providers.impls.vllm.VLLMConfig",
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="runpod",
+            pip_packages=["openai"],
+            module="llama_stack.providers.adapters.inference.runpod",
+            config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
+        ),
+    ),
 ]
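With the spec registered, a distribution can select the adapter by adapter_type="runpod", which pulls in the openai dependency and the config class above. A minimal sketch of constructing that config, assuming RunpodImplConfig exposes exactly the url and api_token fields the adapter reads (its signature is not shown in this diff):

from llama_stack.providers.adapters.inference.runpod import RunpodImplConfig

# Hypothetical values; the adapter consumes these as self.config.url
# and self.config.api_token when building its OpenAI client.
config = RunpodImplConfig(
    url="https://api.runpod.ai/v2/<endpoint-id>/openai/v1",  # assumed URL shape
    api_token="YOUR_RUNPOD_API_TOKEN",
)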