add rp provider

Signed-off-by: pandyamarut <pandyamarut@gmail.com>
This commit is contained in:
pandyamarut 2024-11-03 19:38:22 -05:00
parent 22a1506d65
commit 30a753d80a
2 changed files with 12 additions and 3 deletions

View file

@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import RunpodImplConfig
VLLM_SUPPORTED_MODELS = {
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
@ -64,7 +64,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Model registration is not supported for vLLM models")
raise ValueError("Model registration is not supported for Runpod models")
async def shutdown(self) -> None:
pass
@ -140,7 +140,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": VLLM_SUPPORTED_MODELS[request.model],
"model": RUNPOD_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request.sampling_params),