Add a RoutableProvider protocol, support for multiple routing keys (#163)

* Update configure.py to use multiple routing keys for safety
* Refactor distribution/datatypes into a providers/datatypes
* Cleanup
This commit is contained in:
Ashwin Bharambe 2024-09-30 17:30:21 -07:00 committed by GitHub
parent 73decb3781
commit eb2d8a31a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 600 additions and 577 deletions

View file

@ -10,7 +10,6 @@ from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from together import Together
@ -19,9 +18,11 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
@ -32,8 +33,13 @@ TOGETHER_SUPPORTED_MODELS = {
}
class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
class TogetherInferenceAdapter(
Inference, NeedsRequestProviderData, RoutableProviderForModels
):
def __init__(self, config: TogetherImplConfig) -> None:
RoutableProviderForModels.__init__(
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
)
self.config = config
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
@ -69,18 +75,6 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
return together_messages
def resolve_together_model(self, model_name: str) -> str:
model = resolve_model(model_name)
assert (
model is not None
and model.descriptor(shorten_default_variant=True)
in TOGETHER_SUPPORTED_MODELS
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(TOGETHER_SUPPORTED_MODELS.keys())}"
return TOGETHER_SUPPORTED_MODELS.get(
model.descriptor(shorten_default_variant=True)
)
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
options = {}
if request.sampling_params is not None:
@ -103,12 +97,15 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
) -> AsyncGenerator:
together_api_key = None
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
@ -125,7 +122,7 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
# accumulate sampling params and other options to pass to together
options = self.get_together_chat_options(request)
together_model = self.resolve_together_model(request.model)
together_model = self.map_to_provider_model(request.model)
messages = augment_messages_for_tools(request)
if not request.stream:
@ -171,17 +168,10 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None and chunk.choices[0].finish_reason == "stop"
) or (
stop_reason is None and chunk.choices[0].finish_reason == "eos"
):
if finish_reason := chunk.choices[0].finish_reason:
if stop_reason is None and finish_reason in ["stop", "eos"]:
stop_reason = StopReason.end_of_turn
elif (
stop_reason is None
and chunk.choices[0].finish_reason == "length"
):
elif stop_reason is None and finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break