From 7ece0d4d8b00bcb981713aed158814b8f5d5a28f Mon Sep 17 00:00:00 2001
From: siddharthsambharia-portkey
Date: Fri, 20 Dec 2024 17:21:54 +0530
Subject: [PATCH] portkey integration v1

---
 distributions/portkey/build.yaml              |  17 ++
 distributions/portkey/compose.yaml            |   0
 distributions/portkey/run.yaml                |  77 +++++++
 .../remote/inference/portkey/__init__.py      |  16 ++
 .../remote/inference/portkey/config.py        |  32 +++
 .../remote/inference/portkey/portkey.py       | 190 ++++++++++++++
 6 files changed, 332 insertions(+)
 create mode 100644 distributions/portkey/build.yaml
 create mode 100644 distributions/portkey/compose.yaml
 create mode 100644 distributions/portkey/run.yaml
 create mode 100644 llama_stack/providers/remote/inference/portkey/__init__.py
 create mode 100644 llama_stack/providers/remote/inference/portkey/config.py
 create mode 100644 llama_stack/providers/remote/inference/portkey/portkey.py

diff --git a/distributions/portkey/build.yaml b/distributions/portkey/build.yaml
new file mode 100644
index 000000000..d173be6ff
--- /dev/null
+++ b/distributions/portkey/build.yaml
@@ -0,0 +1,17 @@
+version: '2'
+name: portkey
+distribution_spec:
+  description: Use Portkey for running LLM inference
+  docker_image: null
+  providers:
+    inference:
+    - remote::portkey
+    safety:
+    - inline::llama-guard
+    memory:
+    - inline::meta-reference
+    agents:
+    - inline::meta-reference
+    telemetry:
+    - inline::meta-reference
+image_type: conda
diff --git a/distributions/portkey/compose.yaml b/distributions/portkey/compose.yaml
new file mode 100644
index 000000000..e69de29bb
diff --git a/distributions/portkey/run.yaml b/distributions/portkey/run.yaml
new file mode 100644
index 000000000..b7ac72d6b
--- /dev/null
+++ b/distributions/portkey/run.yaml
@@ -0,0 +1,77 @@
+version: '2'
+image_name: portkey
+docker_image: null
+conda_env: portkey
+apis:
+- agents
+- inference
+- memory
+- safety
+- telemetry
+providers:
+  inference:
+  - provider_id: portkey
+    provider_type: remote::portkey
+    config:
+      base_url: https://api.portkey.ai/v1
+      api_key: ${env.PORTKEY_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config: {}
+  memory:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/portkey}/faiss_store.db
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/portkey}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
+      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/portkey/trace_store.db}
+metadata_store:
+  namespace: null
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/portkey}/registry.db
+models:
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
+  provider_id: portkey
+  provider_model_id: llama3.1-8b
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.3-70B-Instruct
+  provider_id: portkey
+  provider_model_id: llama-3.3-70b
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
+shields:
+- params: null
+  shield_id: meta-llama/Llama-Guard-3-8B
+  provider_id: null
+  provider_shield_id: null
+memory_banks: []
+datasets: []
+scoring_fns: []
+eval_tasks: []
diff --git a/llama_stack/providers/remote/inference/portkey/__init__.py b/llama_stack/providers/remote/inference/portkey/__init__.py
new file mode 100644
index 000000000..aaabc055a
--- /dev/null
+++ b/llama_stack/providers/remote/inference/portkey/__init__.py
@@ -0,0 +1,16 @@
+
+from .config import PortkeyImplConfig
+
+
+async def get_adapter_impl(config: PortkeyImplConfig, _deps):
+    from .portkey import PortkeyInferenceAdapter
+
+    assert isinstance(
+        config, PortkeyImplConfig
+    ), f"Unexpected config type: {type(config)}"
+
+    impl = PortkeyInferenceAdapter(config)
+
+    await impl.initialize()
+
+    return impl
diff --git a/llama_stack/providers/remote/inference/portkey/config.py b/llama_stack/providers/remote/inference/portkey/config.py
new file mode 100644
index 000000000..144fbf6a4
--- /dev/null
+++ b/llama_stack/providers/remote/inference/portkey/config.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Any, Dict, Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+DEFAULT_BASE_URL = "https://api.portkey.ai/v1"
+
+
+@json_schema_type
+class PortkeyImplConfig(BaseModel):
+    base_url: str = Field(
+        default=os.environ.get("PORTKEY_BASE_URL", DEFAULT_BASE_URL),
+        description="Base URL for the Portkey API",
+    )
+    api_key: Optional[str] = Field(
+        default=os.environ.get("PORTKEY_API_KEY"),
+        description="Portkey API Key",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "base_url": DEFAULT_BASE_URL,
+            "api_key": "${env.PORTKEY_API_KEY}",
+        }
diff --git a/llama_stack/providers/remote/inference/portkey/portkey.py b/llama_stack/providers/remote/inference/portkey/portkey.py
new file mode 100644
index 000000000..c8ed5c4c9
--- /dev/null
+++ b/llama_stack/providers/remote/inference/portkey/portkey.py
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator
+
+from portkey_ai import AsyncPortkey
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_stack.apis.inference import *  # noqa: F403
+
+from llama_models.datatypes import CoreModelId
+
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+    completion_request_to_prompt,
+)
+
+from .config import PortkeyImplConfig
+
+
+model_aliases = [
+    build_model_alias(
+        "llama3.1-8b",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_model_alias(
+        "llama-3.3-70b",
+        CoreModelId.llama3_3_70b_instruct.value,
+    ),
+]
+
+
+class PortkeyInferenceAdapter(ModelRegistryHelper, Inference):
+    def __init__(self, config: PortkeyImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self,
+            model_aliases=model_aliases,
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+        self.client = AsyncPortkey(
+            base_url=self.config.base_url, api_key=self.config.api_key
+        )
+
+    async def initialize(self) -> None:
+        return
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def completion(
+        self,
+        model_id: str,
+        content: InterleavedContent,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
+        request = CompletionRequest(
+            model=model.provider_resource_id,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(
+                request,
+            )
+        else:
+            return await self._nonstream_completion(request)
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> CompletionResponse:
+        params = await self._get_params(request)
+
+        r = await self.client.completions.create(**params)
+
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = await self._get_params(request)
+
+        stream = await self.client.completions.create(**params)
+
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
+        request = ChatCompletionRequest(
+            model=model.provider_resource_id,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = await self._get_params(request)
+
+        r = await self.client.completions.create(**params)
+
+        return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = await self._get_params(request)
+
+        stream = await self.client.completions.create(**params)
+
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        if request.sampling_params and request.sampling_params.top_k:
+            raise ValueError("`top_k` not supported by Portkey")
+
+        prompt = ""
+        if isinstance(request, ChatCompletionRequest):
+            prompt = await chat_completion_request_to_prompt(
+                request, self.get_llama_model(request.model), self.formatter
+            )
+        elif isinstance(request, CompletionRequest):
+            prompt = await completion_request_to_prompt(request, self.formatter)
+        else:
+            raise ValueError(f"Unknown request type {type(request)}")
+
+        return {
+            "model": request.model,
+            "prompt": prompt,
+            "stream": request.stream,
+            **get_sampling_options(request.sampling_params),
+        }
+
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[InterleavedContent],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
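
Usage sketch (illustrative note, not part of the patch): once a stack built from the portkey distribution above is running with PORTKEY_API_KEY set, the new provider can be exercised through the regular Llama Stack client API. The local port and the script name below are assumptions for illustration; the model id comes from the models section of run.yaml.

    # smoke_test_portkey.py -- hypothetical example, not included in this patch
    from llama_stack_client import LlamaStackClient

    # Assumes a server started with `llama stack run ./distributions/portkey/run.yaml`
    # and listening on localhost:5001 (adjust to your setup).
    client = LlamaStackClient(base_url="http://localhost:5001")

    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Say hello through Portkey."}],
    )
    print(response.completion_message.content)

A call like this lands in PortkeyInferenceAdapter.chat_completion, which converts the chat request to a prompt and forwards it to Portkey's completions endpoint through the AsyncPortkey client.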