feat: adding snowflake provider and template

alejandro 2024-10-31 19:13:00 -04:00
parent 59d856363e
commit ecb395c751
5 changed files with 92 additions and 22 deletions

View file

@@ -51,6 +51,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
 | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
+| Snowflake | Hosted | | :heavy_check_mark: | | | |
 | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
 | Ollama | Single Node | | :heavy_check_mark: | | | |
 | TGI | Hosted and Single Node | | :heavy_check_mark: | | | |

View file

@@ -11,11 +11,11 @@ from pydantic import BaseModel, Field

 @json_schema_type
 class SnowflakeImplConfig(BaseModel):
-    url: str = Field(
+    account: str = Field(
         default=None,
-        description="The URL for the Snowflake Cortex model serving endpoint",
+        description="The Snowflake Account ID for the Snowflake Cortex model serving endpoint",
     )
-    api_token: str = Field(
+    api_key: str = Field(
         default=None,
         description="The Snowflake Cortex API token",
     )
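
A minimal sketch of the renamed config in use (field names are from this diff; the import path follows the config_class string registered in the provider registry below, and the values are placeholders):

    from llama_stack.providers.adapters.inference.snowflake import SnowflakeImplConfig

    config = SnowflakeImplConfig(
        account="myorg-myaccount",  # placeholder Snowflake account identifier
        api_key="...",              # Cortex API token; note the description still says "token"
    )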

View file

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import json
 from typing import AsyncGenerator

 import httpx
@@ -18,8 +19,6 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
@@ -32,9 +31,9 @@ from .config import SnowflakeImplConfig

 SNOWFLAKE_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "snowflake-meta-Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "snowflake-meta-Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "snowflake-meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "llama3.1-8b",
+    "Llama3.1-70B-Instruct": "llama3.1-70b",
+    "Llama3.1-405B-Instruct": "llama3.1-405b",
 }
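
The new values are Cortex's short model names. The adapter resolves a stack alias through ModelRegistryHelper (imported above), which _get_params below invokes as map_to_provider_model; the lookup it boils down to:

    SNOWFLAKE_SUPPORTED_MODELS["Llama3.1-70B-Instruct"]  # -> "llama3.1-70b"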
@@ -99,7 +98,7 @@ class SnowflakeInferenceAdapter(
         return headers

-    def _get_cortex_client(self, timeout=None, concurrent_limit=1000):
+    def _get_cortex_client(self, timeout=30, concurrent_limit=1000):
         client = httpx.Client(
             timeout=timeout,
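
Worth noting: httpx treats timeout=None as "disable timeouts entirely", so the old default meant wait forever. A bare number covers connect, read, write, and pool acquisition alike, making the new default equivalent to:

    httpx.Client(timeout=httpx.Timeout(30.0))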
@@ -111,6 +110,18 @@ class SnowflakeInferenceAdapter(
         return client

+    def _get_cortex_async_client(self, timeout=30, concurrent_limit=1000):
+        client = httpx.AsyncClient(
+            timeout=timeout,
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            ),
+        )
+        return client
+
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
@@ -125,7 +136,7 @@ class SnowflakeInferenceAdapter(
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_cortex_client().completions.create(**params)
+            s = self._get_cortex_client().post(**params)
             for chunk in s:
                 yield chunk
@@ -193,7 +204,7 @@ class SnowflakeInferenceAdapter(
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = self._get_cortex_client().post(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return self._process_nonstream_snowflake_response(r.text)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -202,21 +213,29 @@ class SnowflakeInferenceAdapter(
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_cortex_client().post(**params)
-            for chunk in s:
-                yield chunk
+            async with self._get_cortex_async_client() as client:
+                async with client.stream("POST", **params) as response:
+                    async for line in response.aiter_lines():
+                        if line.strip():  # Check if line is not empty
+                            yield line

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(
-            stream, self.formatter
-        ):
-            yield chunk
+        # TODO UPDATE PARAM STRUCTURE
+        async for chunk in stream:
+            clean_chunk = self._process_snowflake_stream_response(chunk)
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=clean_chunk,
+                    stop_reason=None,
+                )
+            )

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "url": self._get_cortex_url(),
             "headers": self._get_cortex_headers(),
-            "data": {
+            "json": {
                 "model": self.map_to_provider_model(request.model),
                 "messages": [
                     {
@@ -225,7 +244,6 @@ class SnowflakeInferenceAdapter(
                     )
                 }
             ],
-            "stream": request.stream,
         },
     }
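
Two details worth noting: switching the payload key from "data" to "json" makes httpx serialize the body as JSON and set the Content-Type header (data= would have form-encoded it), and these kwargs are splatted into both client.post(**params) and client.stream("POST", **params) above. Roughly the structure _get_params now returns (the role/content message shape is an assumption, since the chat-format code is elided from this diff):

    {
        "url": "https://<account>.snowflakecomputing.com/api/v2/cortex/inference:complete",
        "headers": {...},  # from _get_cortex_headers, not shown in this diff
        "json": {
            "model": "llama3.1-70b",
            "messages": [{"role": "user", "content": "..."}],
        },
    }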
@@ -235,3 +253,45 @@ class SnowflakeInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    def _process_nonstream_snowflake_response(self, response_str):
+        json_objects = response_str.split("\ndata: ")
+        json_list = []
+
+        # Iterate over each JSON object
+        for obj in json_objects:
+            obj = obj.strip()
+            if obj:
+                # Remove the 'data: ' prefix if it exists
+                if obj.startswith("data: "):
+                    obj = obj[6:]
+                # Load the JSON object into a Python dictionary
+                json_dict = json.loads(obj, strict=False)
+                # Append the JSON dictionary to the list
+                json_list.append(json_dict)
+
+        completion = ""
+        choices = {}
+        for chunk in json_list:
+            choices = chunk["choices"][0]
+            if "content" in choices["delta"].keys():
+                completion += choices["delta"]["content"]
+
+        return completion
+
+    def _process_snowflake_stream_response(self, response_str):
+        if not response_str.startswith("data: "):
+            return ""
+        try:
+            json_dict = json.loads(response_str[6:])
+            return json_dict["choices"][0]["delta"].get("content", "")
+        except (json.JSONDecodeError, KeyError, IndexError):
+            return ""
+
+    def _get_cortex_url(self):
+        account_id = self.config.account
+        cortex_endpoint = f"https://{account_id}.snowflakecomputing.com/api/v2/cortex/inference:complete"
+        return cortex_endpoint
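
Cortex replies as server-sent events, which is what the two parsing helpers above unpack. A hedged walk-through with a fabricated two-event payload (the event shape is inferred from this parsing code, not from Snowflake documentation; adapter stands for an instantiated SnowflakeInferenceAdapter):

    sample = (
        'data: {"choices": [{"delta": {"content": "Hello"}}]}\n'
        'data: {"choices": [{"delta": {"content": " world"}}]}'
    )
    # Non-streaming: split on the "data: " boundaries, parse each event, join deltas
    adapter._process_nonstream_snowflake_response(sample)  # -> "Hello world"
    # Streaming: each line from aiter_lines() is parsed on its own
    adapter._process_snowflake_stream_response(sample.split("\n")[0])  # -> "Hello"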

View file

@@ -101,7 +101,7 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="snowflake",
-            pip_packages=["python-snowflake-snowpark"],
+            pip_packages=["snowflake"],
             module="llama_stack.providers.adapters.inference.snowflake",
             config_class="llama_stack.providers.adapters.inference.snowflake.SnowflakeImplConfig",
         ),
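
The pip package swap matters at build time: "snowflake" is Snowflake's consolidated Python package on PyPI, whereas "python-snowflake-snowpark" does not match a published package name (Snowpark ships as snowflake-snowpark-python). At runtime the module string is resolved by import, roughly:

    import importlib

    module = importlib.import_module(
        "llama_stack.providers.adapters.inference.snowflake"
    )
    # the stack then asks this module for its adapter factory
    # (conventionally get_adapter_impl; name assumed, not shown in this diff)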

View file

@@ -0,0 +1,9 @@
+name: snowflake
+distribution_spec:
+  description: Use Snowflake for running LLM inference
+  providers:
+    inference: remote::snowflake
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference
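
Assuming this template follows the same conventions as the existing ones, it should let a distribution be assembled with something along the lines of (command shape assumed, not part of this commit):

    llama stack build --template snowflake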