forked from phoenix-oss/llama-stack-mirror
Add a RoutableProvider protocol, support for multiple routing keys (#163)
* Update configure.py to use multiple routing keys for safety * Refactor distribution/datatypes into a providers/datatypes * Cleanup
This commit is contained in:
parent
73decb3781
commit
eb2d8a31a5
24 changed files with 600 additions and 577 deletions
|
@ -10,7 +10,6 @@ from llama_models.llama3.api.chat_format import ChatFormat
|
|||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from together import Together
|
||||
|
||||
|
@ -19,9 +18,11 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
|||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
TOGETHER_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
|
@ -32,8 +33,13 @@ TOGETHER_SUPPORTED_MODELS = {
|
|||
}
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
|
||||
class TogetherInferenceAdapter(
|
||||
Inference, NeedsRequestProviderData, RoutableProviderForModels
|
||||
):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
RoutableProviderForModels.__init__(
|
||||
self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
|
||||
)
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
@ -69,18 +75,6 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
|
|||
|
||||
return together_messages
|
||||
|
||||
def resolve_together_model(self, model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert (
|
||||
model is not None
|
||||
and model.descriptor(shorten_default_variant=True)
|
||||
in TOGETHER_SUPPORTED_MODELS
|
||||
), f"Unsupported model: {model_name}, use one of the supported models: {','.join(TOGETHER_SUPPORTED_MODELS.keys())}"
|
||||
|
||||
return TOGETHER_SUPPORTED_MODELS.get(
|
||||
model.descriptor(shorten_default_variant=True)
|
||||
)
|
||||
|
||||
def get_together_chat_options(self, request: ChatCompletionRequest) -> dict:
|
||||
options = {}
|
||||
if request.sampling_params is not None:
|
||||
|
@ -103,12 +97,15 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
|
|||
) -> AsyncGenerator:
|
||||
|
||||
together_api_key = None
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
if self.config.api_key is not None:
|
||||
together_api_key = self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
client = Together(api_key=together_api_key)
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
|
@ -125,7 +122,7 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
|
|||
|
||||
# accumulate sampling params and other options to pass to together
|
||||
options = self.get_together_chat_options(request)
|
||||
together_model = self.resolve_together_model(request.model)
|
||||
together_model = self.map_to_provider_model(request.model)
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
if not request.stream:
|
||||
|
@ -171,17 +168,10 @@ class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
|
|||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if (
|
||||
stop_reason is None and chunk.choices[0].finish_reason == "stop"
|
||||
) or (
|
||||
stop_reason is None and chunk.choices[0].finish_reason == "eos"
|
||||
):
|
||||
if finish_reason := chunk.choices[0].finish_reason:
|
||||
if stop_reason is None and finish_reason in ["stop", "eos"]:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "length"
|
||||
):
|
||||
elif stop_reason is None and finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue