Mirror of https://github.com/meta-llama/llama-stack.git
feat: adding snowflake provider and template
parent 59d856363e · commit ecb395c751
5 changed files with 92 additions and 22 deletions
@@ -51,6 +51,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
 | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
+| Snowflake | Hosted | | :heavy_check_mark: | | | |
 | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
 | Ollama | Single Node | | :heavy_check_mark: | | | |
 | TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
@@ -11,11 +11,11 @@ from pydantic import BaseModel, Field


 @json_schema_type
 class SnowflakeImplConfig(BaseModel):
-    url: str = Field(
+    account: str = Field(
         default=None,
-        description="The URL for the Snowflake Cortex model serving endpoint",
+        description="The Snowflake Account ID for the Snowflake Cortex model serving endpoint",
     )
-    api_token: str = Field(
+    api_key: str = Field(
         default=None,
         description="The Snowflake Cortex API token",
     )
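
For orientation, here is a minimal sketch (not part of the commit) of how the revised config might be constructed. The import path is assumed from the `config_class` string in the registry hunk further down; the environment variable names are hypothetical:

```python
import os

from llama_stack.providers.adapters.inference.snowflake import SnowflakeImplConfig

# Hypothetical wiring: both values are deployment-specific secrets.
config = SnowflakeImplConfig(
    account=os.environ["SNOWFLAKE_ACCOUNT"],  # e.g. "myorg-myaccount" (made-up id)
    api_key=os.environ["SNOWFLAKE_API_KEY"],  # Cortex REST API credential
)
```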
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
 from typing import AsyncGenerator

 import httpx
@@ -18,8 +19,6 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
@@ -32,9 +31,9 @@ from .config import SnowflakeImplConfig


 SNOWFLAKE_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "snowflake-meta-Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "snowflake-meta-Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "snowflake-meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "llama3.1-8b",
+    "Llama3.1-70B-Instruct": "llama3.1-70b",
+    "Llama3.1-405B-Instruct": "llama3.1-405b",
 }
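
The old right-hand values appear to echo Together-style model ids; the new values are the model names Snowflake Cortex expects. A one-line illustration of how the mapping is consumed (via `map_to_provider_model` from `ModelRegistryHelper`, as the `_get_params` hunk below shows):

```python
# Resolving a Llama Stack model id to the Cortex-side model name.
assert SNOWFLAKE_SUPPORTED_MODELS["Llama3.1-8B-Instruct"] == "llama3.1-8b"
```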
@@ -99,7 +98,7 @@ class SnowflakeInferenceAdapter(

         return headers

-    def _get_cortex_client(self, timeout=None, concurrent_limit=1000):
+    def _get_cortex_client(self, timeout=30, concurrent_limit=1000):

         client = httpx.Client(
             timeout=timeout,
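
The default change from `timeout=None` to `timeout=30` is more than cosmetic: in httpx, `None` disables timeouts entirely, while a bare number caps the connect, read, write, and pool phases alike. A standalone illustration (not from the diff):

```python
import httpx

# timeout=None: a stalled Cortex call could block indefinitely.
unbounded = httpx.Client(timeout=None)

# timeout=30: each phase (connect/read/write/pool) gets a 30-second cap.
bounded = httpx.Client(timeout=30)

# Finer-grained control is also available if only connects should be strict:
custom = httpx.Client(timeout=httpx.Timeout(30.0, connect=5.0))
```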
@@ -111,6 +110,18 @@ class SnowflakeInferenceAdapter(

         return client

+    def _get_cortex_async_client(self, timeout=30, concurrent_limit=1000):
+
+        client = httpx.AsyncClient(
+            timeout=timeout,
+            limits=httpx.Limits(
+                max_connections=concurrent_limit,
+                max_keepalive_connections=concurrent_limit,
+            ),
+        )
+
+        return client
+
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
@@ -125,7 +136,7 @@ class SnowflakeInferenceAdapter(

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_cortex_client().completions.create(**params)
+            s = self._get_cortex_client().post(**params)
             for chunk in s:
                 yield chunk

@@ -193,7 +204,7 @@ class SnowflakeInferenceAdapter(
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = self._get_cortex_client().post(**params)
-        return process_chat_completion_response(r, self.formatter)
+        return self._process_nonstream_snowflake_response(r.text)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -202,21 +213,29 @@ class SnowflakeInferenceAdapter(

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_cortex_client().post(**params)
-            for chunk in s:
-                yield chunk
+            async with self._get_cortex_async_client() as client:
+                async with client.stream("POST", **params) as response:
+                    async for line in response.aiter_lines():
+                        if line.strip():  # Check if line is not empty
+                            yield line

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(
-            stream, self.formatter
-        ):
-            yield chunk
+        # TODO UPDATE PARAM STRUCTURE
+        async for chunk in stream:
+            clean_chunk = self._process_snowflake_stream_response(chunk)
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=clean_chunk,
+                    stop_reason=None,
+                )
+            )

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "url": self._get_cortex_url(),
             "headers": self._get_cortex_headers(),
-            "data": {
+            "json": {
                 "model": self.map_to_provider_model(request.model),
                 "messages": [
                     {
@@ -225,7 +244,6 @@ class SnowflakeInferenceAdapter(
                     )
                 }
             ],
             "stream": request.stream,
         },
     }
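
Putting the `_get_params` pieces together: the request body now travels under httpx's `json` keyword (serialized JSON) rather than `data` (form-encoded). An illustrative params dict with a hypothetical account and prompt; the header contents are assumed, since `_get_cortex_headers()` is not shown in this diff:

```python
params = {
    "url": "https://myorg-myaccount.snowflakecomputing.com/api/v2/cortex/inference:complete",
    "headers": {
        # Shape assumed; real values come from _get_cortex_headers().
        "Authorization": "Bearer <api_key>",
        "Content-Type": "application/json",
    },
    "json": {
        "model": "llama3.1-8b",  # via map_to_provider_model / SNOWFLAKE_SUPPORTED_MODELS
        "messages": [{"role": "user", "content": "Hello, Cortex!"}],
        "stream": True,
    },
}

# Both call sites in the adapter unpack this dict directly:
#   self._get_cortex_client().post(**params)
#   client.stream("POST", **params)
```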
@@ -235,3 +253,45 @@ class SnowflakeInferenceAdapter(
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+    def _process_nonstream_snowflake_response(self, response_str):
+
+        json_objects = response_str.split("\ndata: ")
+        json_list = []
+
+        # Iterate over each JSON object
+        for obj in json_objects:
+            obj = obj.strip()
+            if obj:
+                # Remove the 'data: ' prefix if it exists
+                if obj.startswith("data: "):
+                    obj = obj[6:]
+                # Load the JSON object into a Python dictionary
+                json_dict = json.loads(obj, strict=False)
+                # Append the JSON dictionary to the list
+                json_list.append(json_dict)
+
+        completion = ""
+        choices = {}
+        for chunk in json_list:
+            choices = chunk["choices"][0]
+
+            if "content" in choices["delta"].keys():
+                completion += choices["delta"]["content"]
+
+        return completion
+
+    def _process_snowflake_stream_response(self, response_str):
+        if not response_str.startswith("data: "):
+            return ""
+
+        try:
+            json_dict = json.loads(response_str[6:])
+            return json_dict["choices"][0]["delta"].get("content", "")
+        except (json.JSONDecodeError, KeyError, IndexError):
+            return ""
+
+    def _get_cortex_url(self):
+        account_id = self.config.account
+        cortex_endpoint = f"https://{account_id}.snowflakecomputing.com/api/v2/cortex/inference:complete"
+        return cortex_endpoint
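
To see the two new parsers at work, here is a fabricated Cortex-style server-sent-events payload; the shape is inferred from the parsing code itself, not from Snowflake documentation:

```python
# Non-streaming path: the full SSE body arrives as one string (r.text).
sse_text = (
    'data: {"choices": [{"delta": {"content": "Hello"}}]}\n'
    'data: {"choices": [{"delta": {"content": " world"}}]}'
)
# _process_nonstream_snowflake_response(sse_text) -> "Hello world"

# Streaming path: lines arrive one at a time from response.aiter_lines().
line = 'data: {"choices": [{"delta": {"content": "Hi"}}]}'
# _process_snowflake_stream_response(line) -> "Hi"
# Any line that is not well-formed 'data: {...}' safely returns "".
```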
@@ -101,7 +101,7 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="snowflake",
-            pip_packages=["python-snowflake-snowpark"],
+            pip_packages=["snowflake"],
             module="llama_stack.providers.adapters.inference.snowflake",
             config_class="llama_stack.providers.adapters.inference.snowflake.SnowflakeImplConfig",
         ),
llama_stack/templates/snowflake/build.yaml (new file, 9 lines)

@@ -0,0 +1,9 @@
+name: snowflake
+distribution_spec:
+  description: Use Snowflake for running LLM inference
+  providers:
+    inference: remote::snowflake
+    memory: meta-reference
+    safety: meta-reference
+    agents: meta-reference
+    telemetry: meta-reference