diff --git a/README.md b/README.md
index 251b81513..c75b30a5c 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
 | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
+| Snowflake | Hosted | | :heavy_check_mark: | | | |
 | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
 | Ollama | Single Node | | :heavy_check_mark: | | | |
 | TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
diff --git a/llama_stack/providers/adapters/inference/snowflake/config.py b/llama_stack/providers/adapters/inference/snowflake/config.py
index 7987cb916..35e27b75f 100644
--- a/llama_stack/providers/adapters/inference/snowflake/config.py
+++ b/llama_stack/providers/adapters/inference/snowflake/config.py
@@ -11,11 +11,11 @@ from pydantic import BaseModel, Field

 @json_schema_type
 class SnowflakeImplConfig(BaseModel):
-    url: str = Field(
+    account: str = Field(
         default=None,
-        description="The URL for the Snowflake Cortex model serving endpoint",
+        description="The Snowflake account ID for the Cortex model serving endpoint",
     )
-    api_token: str = Field(
+    api_key: str = Field(
         default=None,
-        description="The Snowflake Cortex API token",
+        description="The Snowflake Cortex API key",
     )
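With this change the adapter derives the Cortex endpoint from a Snowflake account ID rather than taking a full URL. A minimal sketch of the resulting configuration (the values are placeholders; the URL expansion mirrors `_get_cortex_url` later in this diff):

```python
from llama_stack.providers.adapters.inference.snowflake.config import SnowflakeImplConfig

# Placeholder credentials; real deployments supply their own account ID and key.
config = SnowflakeImplConfig(
    account="myorg-myaccount",
    api_key="...",
)

# The adapter expands the account ID into the Cortex REST endpoint:
url = f"https://{config.account}.snowflakecomputing.com/api/v2/cortex/inference:complete"
```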
diff --git a/llama_stack/providers/adapters/inference/snowflake/snowflake.py b/llama_stack/providers/adapters/inference/snowflake/snowflake.py
index 3f7a72600..01ace6b7e 100644
--- a/llama_stack/providers/adapters/inference/snowflake/snowflake.py
+++ b/llama_stack/providers/adapters/inference/snowflake/snowflake.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
 from typing import AsyncGenerator

 import httpx
@@ -18,8 +19,6 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
@@ -32,9 +31,9 @@ from .config import SnowflakeImplConfig

 SNOWFLAKE_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "snowflake-meta-Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "snowflake-meta-Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "snowflake-meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "llama3.1-8b",
+    "Llama3.1-70B-Instruct": "llama3.1-70b",
+    "Llama3.1-405B-Instruct": "llama3.1-405b",
 }
@@ -99,7 +98,7 @@ class SnowflakeInferenceAdapter(

         return headers

-    def _get_cortex_client(self, timeout=None, concurrent_limit=1000):
+    def _get_cortex_client(self, timeout=30, concurrent_limit=1000):

         client = httpx.Client(
             timeout=timeout,
@@ -111,6 +110,18 @@ class SnowflakeInferenceAdapter(

         return client

+    def _get_cortex_async_client(self, timeout=30, concurrent_limit=1000):
+
+        client = httpx.AsyncClient(
+            timeout=timeout,
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            ),
+        )
+
+        return client
+
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
@@ -125,7 +136,7 @@ class SnowflakeInferenceAdapter(

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_cortex_client().completions.create(**params)
+            s = self._get_cortex_client().post(**params)
             for chunk in s:
                 yield chunk

@@ -193,7 +204,7 @@ class SnowflakeInferenceAdapter(
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = self._get_cortex_client().post(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return self._process_nonstream_snowflake_response(r.text)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -202,21 +213,29 @@

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_cortex_client().post(**params)
-            for chunk in s:
-                yield chunk
+            async with self._get_cortex_async_client() as client:
+                async with client.stream("POST", **params) as response:
+                    async for line in response.aiter_lines():
+                        if line.strip():  # Check if line is not empty
+                            yield line

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(
-            stream, self.formatter
-        ):
-            yield chunk
-        # TODO UPDATE PARAM STRUCTURE
+        async for chunk in stream:
+            clean_chunk = self._process_snowflake_stream_response(chunk)
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=clean_chunk,
+                    stop_reason=None,
+                )
+            )
+

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
+            "url": self._get_cortex_url(),
             "headers": self._get_cortex_headers(),
-            "data": {
+            "json": {
                 "model": self.map_to_provider_model(request.model),
                 "messages": [
                     {
@@ -225,7 +244,6 @@ class SnowflakeInferenceAdapter(
                     )
                 }
             ],
-            "stream": request.stream,
         },
     }
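For reference, a sketch of the dict `_get_params` now builds and how the two call paths consume it (values are illustrative; the header shape is an assumption, since `_get_cortex_headers` is unchanged and not shown in this diff):

```python
params = {
    "url": "https://myorg-myaccount.snowflakecomputing.com/api/v2/cortex/inference:complete",
    "headers": {"Authorization": "Bearer ..."},  # assumed shape; actually built by _get_cortex_headers()
    "json": {
        "model": "llama3.1-8b",
        "messages": [{"role": "user", "content": "Hello"}],
    },
}

# Non-streaming path: self._get_cortex_client().post(**params)
# Streaming path:     client.stream("POST", **params) on the async client
```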
@@ -235,3 +253,45 @@ class SnowflakeInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    def _process_nonstream_snowflake_response(self, response_str):
+
+        json_objects = response_str.split("\ndata: ")
+        json_list = []
+
+        # Iterate over each JSON object
+        for obj in json_objects:
+            obj = obj.strip()
+            if obj:
+                # Remove the 'data: ' prefix if it exists
+                if obj.startswith("data: "):
+                    obj = obj[6:]
+                # Load the JSON object into a Python dictionary
+                json_dict = json.loads(obj, strict=False)
+                # Append the JSON dictionary to the list
+                json_list.append(json_dict)
+
+        completion = ""
+        choices = {}
+        for chunk in json_list:
+            choices = chunk["choices"][0]
+
+            if "content" in choices["delta"].keys():
+                completion += choices["delta"]["content"]
+
+        return completion
+
+    def _process_snowflake_stream_response(self, response_str):
+        if not response_str.startswith("data: "):
+            return ""
+
+        try:
+            json_dict = json.loads(response_str[6:])
+            return json_dict["choices"][0]["delta"].get("content", "")
+        except (json.JSONDecodeError, KeyError, IndexError):
+            return ""
+
+    def _get_cortex_url(self):
+        account_id = self.config.account
+        cortex_endpoint = f"https://{account_id}.snowflakecomputing.com/api/v2/cortex/inference:complete"
+        return cortex_endpoint
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index a7e788c83..0192812cb 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -101,7 +101,7 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="snowflake",
-            pip_packages=["python-snowflake-snowpark"],
+            pip_packages=["snowflake"],
             module="llama_stack.providers.adapters.inference.snowflake",
             config_class="llama_stack.providers.adapters.inference.snowflake.SnowflakeImplConfig",
         ),
diff --git a/llama_stack/templates/snowflake/build.yaml b/llama_stack/templates/snowflake/build.yaml
new file mode 100644
index 000000000..ae9ebb4ae
--- /dev/null
+++ b/llama_stack/templates/snowflake/build.yaml
@@ -0,0 +1,9 @@
+name: snowflake
+distribution_spec:
+  description: Use Snowflake for running LLM inference
+  providers:
+    inference: remote::snowflake
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
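Both parsing helpers assume Cortex replies with OpenAI-style server-sent events. A self-contained sketch of that assumption, mirroring the logic of `_process_snowflake_stream_response` (the payloads are illustrative, not captured API output):

```python
import json

# Illustrative SSE lines in the shape the helpers above expect.
lines = [
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    'data: {"choices": [{"delta": {"content": " world"}}]}',
    "",  # blank keep-alive lines are skipped by the streaming loop
]

def extract_delta(line: str) -> str:
    # Same logic as _process_snowflake_stream_response: strip the
    # "data: " prefix and pull the incremental delta content.
    if not line.startswith("data: "):
        return ""
    try:
        payload = json.loads(line[6:])
        return payload["choices"][0]["delta"].get("content", "")
    except (json.JSONDecodeError, KeyError, IndexError):
        return ""

assert "".join(extract_delta(line) for line in lines) == "Hello world"
```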