fixing safety inference and safety adapter for new API spec. Pinned t… (#105)

* fixing safety inference and safety adapter for new API spec. Pinned the llama_models version to 0.0.24 as the latest version 0.0.35 has the model descriptor name changed. I was getting the missing package error during runtime as well, hence added the dependency to requirements.txt

* support Llama 3.2 models in Together inference adapter and cleanup Together safety adapter

* fixing model names

* adding vision guard to Together safety
This commit is contained in:
Yogish Baliga 2024-09-28 15:45:38 -07:00 committed by GitHub
parent 0a3999a9a4
commit 940968ee3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 68 additions and 40 deletions

View file

@ -18,13 +18,17 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from llama_stack.distribution.request_headers import get_request_provider_data
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
}
@ -97,6 +101,16 @@ class TogetherInferenceAdapter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
together_api_key = None
provider_data = get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
@ -116,7 +130,7 @@ class TogetherInferenceAdapter(Inference):
if not request.stream:
# TODO: might need to add back an async here
r = self.client.chat.completions.create(
r = client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=False,
@ -151,7 +165,7 @@ class TogetherInferenceAdapter(Inference):
ipython = False
stop_reason = None
for chunk in self.client.chat.completions.create(
for chunk in client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=True,