From 59d856363eccb5bb7f645397e6725604c47ee31a Mon Sep 17 00:00:00 2001
From: alejandro
Date: Thu, 31 Oct 2024 13:00:22 -0400
Subject: [PATCH] feat: initial implementation of snowflake provider + distro

---
 distributions/snowflake/build.yaml            |   9 +
 .../adapters/inference/snowflake/__init__.py  |  17 ++
 .../adapters/inference/snowflake/config.py    |  21 ++
 .../adapters/inference/snowflake/snowflake.py | 237 ++++++++++++++++++
 llama_stack/providers/registry/inference.py   |   9 +
 5 files changed, 293 insertions(+)
 create mode 100644 distributions/snowflake/build.yaml
 create mode 100644 llama_stack/providers/adapters/inference/snowflake/__init__.py
 create mode 100644 llama_stack/providers/adapters/inference/snowflake/config.py
 create mode 100644 llama_stack/providers/adapters/inference/snowflake/snowflake.py

diff --git a/distributions/snowflake/build.yaml b/distributions/snowflake/build.yaml
new file mode 100644
index 000000000..ae9ebb4ae
--- /dev/null
+++ b/distributions/snowflake/build.yaml
@@ -0,0 +1,9 @@
+name: snowflake
+distribution_spec:
+  description: Use Snowflake for running LLM inference
+  providers:
+    inference: remote::snowflake
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
diff --git a/llama_stack/providers/adapters/inference/snowflake/__init__.py b/llama_stack/providers/adapters/inference/snowflake/__init__.py
new file mode 100644
index 000000000..9a265f850
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/snowflake/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import SnowflakeImplConfig
+from .snowflake import SnowflakeInferenceAdapter
+
+
+async def get_adapter_impl(config: SnowflakeImplConfig, _deps):
+    assert isinstance(
+        config, SnowflakeImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = SnowflakeInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/adapters/inference/snowflake/config.py b/llama_stack/providers/adapters/inference/snowflake/config.py
new file mode 100644
index 000000000..7987cb916
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/snowflake/config.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class SnowflakeImplConfig(BaseModel):
+    url: str = Field(
+        default=None,
+        description="The URL for the Snowflake Cortex model serving endpoint",
+    )
+    api_token: str = Field(
+        default=None,
+        description="The Snowflake Cortex API token",
+    )
diff --git a/llama_stack/providers/adapters/inference/snowflake/snowflake.py b/llama_stack/providers/adapters/inference/snowflake/snowflake.py
new file mode 100644
index 000000000..3f7a72600
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/snowflake/snowflake.py
@@ -0,0 +1,237 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator
+
+import httpx
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+    completion_request_to_prompt,
+)
+
+from .config import SnowflakeImplConfig
+
+
+SNOWFLAKE_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "snowflake-meta-Llama-3.1-8B-Instruct-Turbo",
+    "Llama3.1-70B-Instruct": "snowflake-meta-Llama-3.1-70B-Instruct-Turbo",
+    "Llama3.1-405B-Instruct": "snowflake-meta-Llama-3.1-405B-Instruct-Turbo",
+}
+
+
+class SnowflakeInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
+
+    def __init__(self, config: SnowflakeImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=SNOWFLAKE_SUPPORTED_MODELS
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_cortex_headers(
+        self,
+    ):
+        snowflake_api_key = None
+        if self.config.api_token is not None:
+            snowflake_api_key = self.config.api_token
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.snowflake_api_key:
+                raise ValueError(
+                    'Pass Snowflake API Key in the header X-LlamaStack-ProviderData as { "snowflake_api_key": <your api key> }'
+                )
+            snowflake_api_key = provider_data.snowflake_api_key
+
+        headers = {
+            "Accept": "text/stream",
+            "Content-Type": "application/json",
+            "Authorization": f'Snowflake Token="{snowflake_api_key}"',
+        }
+
+        return headers
+
+    def _get_cortex_client(self, timeout=None, concurrent_limit=1000):
+
+        client = httpx.Client(
+            timeout=timeout,
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            ),
+        )
+
+        return client
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_completion(request)
+        r = self._get_cortex_client().post(**params)
+        return process_completion_response(
+            r, self.formatter
+        )  # TODO VALIDATE COMPLETION PROCESSOR
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        # httpx.Client is synchronous, so wrap the response in an async generator
+        async def _to_async_generator():
+            s = self._get_cortex_client().post(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    def _build_options(
+        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.json_schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **self._build_options(request.sampling_params, request.response_format),
+        }
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = self._get_cortex_client().post(**params)
+        return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        # httpx.Client is synchronous, so wrap the response in an async generator
+        async def _to_async_generator():
+            s = self._get_cortex_client().post(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    # TODO UPDATE PARAM STRUCTURE
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "headers": self._get_cortex_headers(),
+            "data": {
+                "model": self.map_to_provider_model(request.model),
+                "messages": [
+                    {
+                        "content": chat_completion_request_to_prompt(
+                            request, self.formatter
+                        )
+                    }
+                ],
+                "stream": request.stream,
+            },
+        }
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 88265f1b4..a7e788c83 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -97,6 +97,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="snowflake", + pip_packages=["python-snowflake-snowpark"], + module="llama_stack.providers.adapters.inference.snowflake", + config_class="llama_stack.providers.adapters.inference.snowflake.SnowflakeImplConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec(