Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00
Fix the safety inference and safety adapters for the new API spec. Pin llama_models to 0.0.24, since the latest release (0.0.35) renamed the model descriptor. I was also hitting a missing-package error at runtime, so the dependency is now listed in requirements.txt.
This commit is contained in:
parent 53070e34a3
commit 9bb0c8f4fc
4 changed files with 33 additions and 26 deletions
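The requirements.txt hunk itself is not among the rendered diffs below. A plausible sketch of the pin described in the commit message, assuming the distribution names match the llama_models and together imports used in this change:

# Hypothetical requirements.txt entries; the names are inferred from the
# imports in this diff, not copied from the actual file.
llama_models==0.0.24
together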
@@ -4,17 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
-
 from llama_models.schema_utils import json_schema_type
 
-from llama_stack.distribution.request_headers import annotate_header
+from pydantic import BaseModel, Field
 
 
 class TogetherHeaderExtractor(BaseModel):
-    api_key: annotate_header(
-        "X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
-    )
+    together_api_key: str
 
-
 @json_schema_type
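This first hunk is the Together adapter's config module: the annotate_header-based api_key field gives way to a plain together_api_key: str on the pydantic model. A minimal sketch of how such a model validates the provider-data payload; the JSON shape mirrors the error message used later in this diff, and the key value is made up:

import json

from pydantic import BaseModel


class TogetherHeaderExtractor(BaseModel):
    together_api_key: str


# Parse the JSON carried in the X-LlamaStack-ProviderData header and
# validate it against the model; a missing key raises a ValidationError.
raw = '{"together_api_key": "tok-123"}'  # hypothetical key value
extracted = TogetherHeaderExtractor(**json.loads(raw))
print(extracted.together_api_key)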
@@ -18,6 +18,7 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.distribution.request_headers import get_request_provider_data
 
 from .config import TogetherImplConfig
 
@@ -97,6 +98,16 @@ class TogetherInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+
+        together_api_key = None
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+            )
+        together_api_key = provider_data.together_api_key
+
+        client = Together(api_key=together_api_key)
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
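For callers, the new check moves the key out of server config and into a per-request header. A client-side sketch that satisfies it, assuming a locally running Llama Stack server; only the header name and JSON shape come from the diff, while the URL, route, and model name are illustrative:

import json

import requests

headers = {
    # Header name and payload shape are taken from the ValueError above.
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "tok-123"})
}
body = {
    "model": "Llama3.1-8B-Instruct",  # hypothetical model identifier
    "messages": [{"role": "user", "content": "Hello!"}],
}
resp = requests.post(
    "http://localhost:5000/inference/chat_completion",  # illustrative endpoint
    headers=headers,
    json=body,
)
print(resp.status_code)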
@@ -116,7 +127,7 @@ class TogetherInferenceAdapter(Inference):
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = self.client.chat.completions.create(
+            r = client.chat.completions.create(
                 model=together_model,
                 messages=self._messages_to_together_messages(messages),
                 stream=False,
@@ -151,7 +162,7 @@ class TogetherInferenceAdapter(Inference):
         ipython = False
         stop_reason = None
 
-        for chunk in self.client.chat.completions.create(
+        for chunk in client.chat.completions.create(
             model=together_model,
             messages=self._messages_to_together_messages(messages),
             stream=True,
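The two hunks above switch both call sites from self.client to the request-scoped client created earlier; since the API key now arrives with each request, a client cached on self could leak one caller's credentials into another caller's calls. A sketch of the pattern in isolation (the helper name is mine, not the adapter's):

from together import Together


def client_for_request(together_api_key: str) -> Together:
    # One client per request keeps each caller's credentials isolated.
    return Together(api_key=together_api_key)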
@@ -6,9 +6,16 @@
 
 from together import Together
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
+from llama_stack.distribution.request_headers import get_request_provider_data
 
-from .config import TogetherProviderDataValidator, TogetherSafetyConfig
+from .config import TogetherSafetyConfig
 
 
 class TogetherSafetyImpl(Safety):
@@ -24,21 +31,13 @@ class TogetherSafetyImpl(Safety):
         if shield_type != "llama_guard":
             raise ValueError(f"shield type {shield_type} is not supported")
 
-        provider_data = get_request_provider_data()
-
-        together_api_key = None
-        if provider_data is not None:
-            if not isinstance(provider_data, TogetherProviderDataValidator):
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
             raise ValueError(
                 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
             )
 
         together_api_key = provider_data.together_api_key
-        if not together_api_key:
-            together_api_key = self.config.api_key
-
-        if not together_api_key:
-            raise ValueError("The API key must be provider in the header or config")
 
         # messages can have role assistant or user
         api_messages = []
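The trailing comment refers to the conversion loop that follows this hunk, which is not part of the change. A hedged sketch of what that conversion typically looks like; the Message dataclass here is a stand-in for the llama_models datatypes, not the real type:

from dataclasses import dataclass


@dataclass
class Message:  # stand-in for the llama_models message datatypes
    role: str
    content: str


messages = [Message("user", "hello"), Message("system", "ignored")]

# Keep only user/assistant turns, flattened into Together's chat format.
api_messages = [
    {"role": m.role, "content": m.content}
    for m in messages
    if m.role in ("user", "assistant")
]
print(api_messages)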
@@ -62,7 +61,7 @@ async def get_safety_response(
     response_text = response.choices[0].message.content
     if response_text == "safe":
-        return None
+        return SafetyViolation(
+            violation_level=ViolationLevel.INFO, user_message="safe", metadata={}
+        )
 
     parts = response_text.split("\n")
     if len(parts) != 2:
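The two-line split reflects Llama Guard's reply format: a verdict on the first line and, for unsafe content, a violated-category code on the second. A small sketch of the parse under that assumption; the category code shown is illustrative:

response_text = "unsafe\nO1"  # illustrative Llama Guard reply

if response_text == "safe":
    verdict, category = "safe", None
else:
    parts = response_text.split("\n")
    if len(parts) != 2:
        raise ValueError(f"unexpected response: {response_text!r}")
    verdict, category = parts

print(verdict, category)  # -> unsafe O1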
@@ -91,7 +91,7 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.adapters.inference.together",
             config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
-            header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
+            provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
         ),
     ),
 ]
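With header_extractor_class gone, the registry entry names a validator by dotted path, which the stack can resolve and apply generically. A minimal sketch of that resolution, assuming the validator is a pydantic model like the ones above:

import importlib
import json


def resolve_class(dotted_path: str):
    # Split "pkg.module.ClassName" into module and attribute, then import.
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


# Hypothetical usage with the validator named in the hunk above:
# validator_cls = resolve_class(
#     "llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator"
# )
# provider_data = validator_cls(**json.loads(header_value))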