feat: initial implementation of snowflake provider + distro

This commit is contained in:
alejandro 2024-10-31 13:00:22 -04:00
parent 4aa1bf6a60
commit 59d856363e
5 changed files with 293 additions and 0 deletions

View file

@@ -0,0 +1,9 @@
name: snowflake
distribution_spec:
  description: Use Snowflake for running LLM inference
  providers:
    inference: remote::snowflake
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference

View file

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import SnowflakeImplConfig
from .snowflake import SnowflakeInferenceAdapter


async def get_adapter_impl(config: SnowflakeImplConfig, _deps):
    assert isinstance(
        config, SnowflakeImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = SnowflakeInferenceAdapter(config)
    await impl.initialize()
    return impl

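For reference, a minimal sketch of exercising this factory outside the stack. The endpoint URL and token below are placeholders, and passing an empty deps dict is an assumption about how the stack would invoke it, not something this commit defines.

import asyncio

from llama_stack.providers.adapters.inference.snowflake import (
    SnowflakeImplConfig,
    get_adapter_impl,
)


async def main():
    config = SnowflakeImplConfig(
        url="https://<account>.snowflakecomputing.com/<cortex-endpoint>",  # placeholder
        api_token="<your api token>",  # placeholder
    )
    adapter = await get_adapter_impl(config, {})  # deps are unused by this adapter
    print(type(adapter).__name__)  # SnowflakeInferenceAdapter


asyncio.run(main())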
View file

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class SnowflakeImplConfig(BaseModel):
    url: Optional[str] = Field(
        default=None,
        description="The URL for the Snowflake Cortex model serving endpoint",
    )
    api_token: Optional[str] = Field(
        default=None,
        description="The Snowflake Cortex API token",
    )

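As a quick illustration (not part of this commit), a run-config style dict validates straight into this model; the values are placeholders, and api_token may be left unset when the key is supplied per request instead (see the adapter's header handling below).

from llama_stack.providers.adapters.inference.snowflake.config import SnowflakeImplConfig

raw = {
    "url": "https://<account>.snowflakecomputing.com/<cortex-endpoint>",  # placeholder
    "api_token": None,  # omit here and pass the key per request instead
}
config = SnowflakeImplConfig(**raw)
print(config.url)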
View file

@@ -0,0 +1,237 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

import httpx

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
)

from .config import SnowflakeImplConfig

SNOWFLAKE_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "snowflake-meta-Llama-3.1-8B-Instruct-Turbo",
    "Llama3.1-70B-Instruct": "snowflake-meta-Llama-3.1-70B-Instruct-Turbo",
    "Llama3.1-405B-Instruct": "snowflake-meta-Llama-3.1-405B-Instruct-Turbo",
}
class SnowflakeInferenceAdapter(
    ModelRegistryHelper, Inference, NeedsRequestProviderData
):
    def __init__(self, config: SnowflakeImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=SNOWFLAKE_SUPPORTED_MODELS
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass
    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = CompletionRequest(
            model=model,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)
    def _get_cortex_headers(
        self,
    ):
        snowflake_api_key = None
        # the config field is api_token; fall back to per-request provider data if unset
        if self.config.api_token is not None:
            snowflake_api_key = self.config.api_token
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.snowflake_api_key:
                raise ValueError(
                    'Pass Snowflake API Key in the header X-LlamaStack-ProviderData as { "snowflake_api_key": <your api key>}'
                )
            snowflake_api_key = provider_data.snowflake_api_key

        headers = {
            "Accept": "text/stream",
            "Content-Type": "application/json",
            "Authorization": f'Snowflake Token="{snowflake_api_key}"',
        }
        return headers

    def _get_cortex_client(self, timeout=None, concurrent_limit=1000):
        client = httpx.Client(
            timeout=timeout,
            limits=httpx.Limits(
                max_connections=concurrent_limit,
                max_keepalive_connections=concurrent_limit,
            ),
        )
        return client
    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> ChatCompletionResponse:
        params = self._get_params_for_completion(request)
        r = self._get_cortex_client().post(**params)
        return process_completion_response(
            r, self.formatter
        )  # TODO VALIDATE COMPLETION PROCESSOR

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = self._get_params_for_completion(request)

        # if we shift to httpx.AsyncClient, we won't need this wrapper
        async def _to_async_generator():
            s = self._get_cortex_client().post(**params)
            # TODO: map raw response lines into OpenAI-compatible chunks
            for chunk in s.iter_lines():
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk
    def _build_options(
        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
    ) -> dict:
        options = get_sampling_options(sampling_params)
        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["response_format"] = {
                    "type": "json_object",
                    "schema": fmt.json_schema,
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                raise NotImplementedError("Grammar response format not supported yet")
            else:
                raise ValueError(f"Unknown response format {fmt.type}")
        return options

    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
        # mirror the chat-completion request shape so the httpx post() calls above
        # receive url/headers/json rather than bare keyword arguments
        return {
            "url": self.config.url,
            "headers": self._get_cortex_headers(),
            "json": {
                "model": self.map_to_provider_model(request.model),
                "prompt": completion_request_to_prompt(request, self.formatter),
                "stream": request.stream,
                **self._build_options(request.sampling_params, request.response_format),
            },
        }
    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)
    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = self._get_cortex_client().post(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        params = self._get_params(request)

        # if we shift to httpx.AsyncClient, we won't need this wrapper
        async def _to_async_generator():
            s = self._get_cortex_client().post(**params)
            # TODO: map raw response lines into OpenAI-compatible chunks
            for chunk in s.iter_lines():
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk
    # TODO UPDATE PARAM STRUCTURE
    def _get_params(self, request: ChatCompletionRequest) -> dict:
        return {
            "url": self.config.url,
            "headers": self._get_cortex_headers(),
            # send the payload as JSON to match the Content-Type header above
            "json": {
                "model": self.map_to_provider_model(request.model),
                "messages": [
                    {
                        "content": chat_completion_request_to_prompt(
                            request, self.formatter
                        )
                    }
                ],
                "stream": request.stream,
            },
        }
    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

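Since _get_cortex_headers falls back to request-provider data when no api_token is configured, a client would pass the key in the X-LlamaStack-ProviderData header. A hedged sketch of what that could look like against a locally running stack; the server URL and route are placeholders, not something this commit defines.

import json

import httpx

headers = {
    "X-LlamaStack-ProviderData": json.dumps({"snowflake_api_key": "<your api key>"}),
    "Content-Type": "application/json",
}
# placeholder endpoint: whatever host/route the llama-stack server exposes for chat completion
response = httpx.post(
    "http://localhost:5000/inference/chat_completion",
    headers=headers,
    json={
        "model": "Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(response.status_code)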
View file

@@ -97,6 +97,15 @@ def available_providers() -> List[ProviderSpec]:
                config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="snowflake",
                pip_packages=["snowflake-snowpark-python"],
                module="llama_stack.providers.adapters.inference.snowflake",
                config_class="llama_stack.providers.adapters.inference.snowflake.SnowflakeImplConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(