From a814755f479c41a0efb0051172db848ffbc60674 Mon Sep 17 00:00:00 2001
From: pandyamarut
Date: Sun, 3 Nov 2024 19:53:22 -0500
Subject: [PATCH] add rp provider

Signed-off-by: pandyamarut
---
 .../adapters/inference/runpod/runpod.py | 53 ++++++-------------
 1 file changed, 16 insertions(+), 37 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/runpod/runpod.py b/llama_stack/providers/adapters/inference/runpod/runpod.py
index a6255dfe3..cb2e6b237 100644
--- a/llama_stack/providers/adapters/inference/runpod/runpod.py
+++ b/llama_stack/providers/adapters/inference/runpod/runpod.py
@@ -12,7 +12,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+# from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -38,43 +39,21 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
     "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
-    "Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
-    "Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
-    "Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
-    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
-    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
-    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
-    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
-    "Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
-    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
-    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
-    "Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
-    "Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
-    "Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
 }
-
-
-class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
+class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
+        )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
-        self.client = None
 
     async def initialize(self) -> None:
-        self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
-
-    async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Model registration is not supported for Runpod models")
+        return
 
     async def shutdown(self) -> None:
         pass
 
-    async def list_models(self) -> List[ModelDef]:
-        return [
-            ModelDef(identifier=model.id, llama_model=model.id)
-            for model in self.client.models.list()
-        ]
-
     async def completion(
         self,
         model: str,
@@ -83,7 +62,7 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     async def chat_completion(
@@ -108,25 +87,25 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
             stream=stream,
             logprobs=logprobs,
         )
+
+        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
         if stream:
-            return self._stream_chat_completion(request, self.client)
+            return self._stream_chat_completion(request, client)
         else:
-            return await self._nonstream_chat_completion(request, self.client)
+            return await self._nonstream_chat_completion(request, client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(request, r, self.formatter)
+        return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> AsyncGenerator:
         params = self._get_params(request)
 
-        # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
-        # generator so this wrapper is not necessary?
         async def _to_async_generator():
            s = client.completions.create(**params)
            for chunk in s:
@@ -134,13 +113,13 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
 
         stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
-            request, stream, self.formatter
+            stream, self.formatter
        ):
            yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
-            "model": RUNPOD_SUPPORTED_MODELS[request.model],
+            "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
@@ -151,4 +130,4 @@ class RunpodInferenceAdapter(Inference, ModelsProtocolPrivate):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        raise NotImplementedError()
\ No newline at end of file
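
Note on the streaming path: the patch keeps the in-line `_to_async_generator` wrapper and drops the TODO asking whether the sync OpenAI stream can be bridged to an async generator more directly. A minimal standalone sketch of one alternative, not part of this patch and using hypothetical names (`fake_sync_stream`, `to_async_generator`): drain the blocking iterator in a worker thread and hand chunks to the event loop through a queue, so the loop is not blocked between chunks (error propagation omitted for brevity).

    import asyncio
    from typing import AsyncGenerator, Iterable


    def fake_sync_stream() -> Iterable[str]:
        """Stand-in for a blocking chunk iterator from an OpenAI-compatible client."""
        yield from ("hello", " ", "world")


    async def to_async_generator(sync_iter: Iterable[str]) -> AsyncGenerator[str, None]:
        # Run the blocking iterator in a worker thread; hand each chunk to the
        # event loop through a queue. None marks the end of the stream.
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def producer() -> None:
            for chunk in sync_iter:
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
            loop.call_soon_threadsafe(queue.put_nowait, None)

        task = loop.run_in_executor(None, producer)
        while (chunk := await queue.get()) is not None:
            yield chunk
        await task  # surface any exception raised in the worker thread


    async def main() -> None:
        async for chunk in to_async_generator(fake_sync_stream()):
            print(chunk, end="")


    asyncio.run(main())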