Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-09 11:20:58 +00:00
Make Safety test work, other cleanup
commit fcd22b6baa
parent ba1f294cc6
16 changed files with 229 additions and 123 deletions
@@ -109,3 +109,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             "stream": request.stream,
             **get_sampling_options(request),
         }
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
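The hunk above gives the Databricks adapter an embeddings() method that only raises NotImplementedError (the Ollama adapter below gets the same stub). A minimal test sketch of that contract follows; the pytest-asyncio marker and the `adapter` fixture are assumptions for illustration, not part of this commit.

# Hypothetical test sketch, not part of this commit: the new embeddings() stub
# is expected to raise NotImplementedError until a real implementation lands.
import pytest


@pytest.mark.asyncio  # assumes pytest-asyncio is available
async def test_embeddings_not_implemented(adapter):  # `adapter` is an assumed fixture
    with pytest.raises(NotImplementedError):
        await adapter.embeddings(
            model="Llama3.1-8B-Instruct",  # placeholder model identifier
            contents=["hello world"],
        )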
@@ -15,7 +15,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
+
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -36,7 +37,7 @@ OLLAMA_SUPPORTED_MODELS = {
 }
 
 
-class OllamaInferenceAdapter(Inference, Models):
+class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
    def __init__(self, url: str) -> None:
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -58,26 +59,30 @@ class OllamaInferenceAdapter(Inference, Models):
         pass
 
     async def register_model(self, model: ModelDef) -> None:
-        if model.identifier not in OLLAMA_SUPPORTED_MODELS:
-            raise ValueError(
-                f"Unsupported model {model.identifier}. Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}"
-            )
+        raise ValueError("Dynamic model registration is not supported")
+
+    async def list_models(self) -> List[ModelDef]:
+        ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
+
+        ret = []
+        res = await self.client.ps()
+        for r in res["models"]:
+            if r["model"] not in ollama_to_llama:
+                print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
+                continue
+
+            llama_model = ollama_to_llama[r["model"]]
+            ret.append(
+                ModelDef(
+                    identifier=llama_model,
+                    llama_model=llama_model,
+                    metadata={
+                        "ollama_model": r["model"],
+                    },
+                )
+            )
 
-        ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier]
-        res = await self.client.ps()
-        need_model_pull = True
-        for r in res["models"]:
-            if ollama_model == r["model"]:
-                need_model_pull = False
-                break
-
-        print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}")
-        if need_model_pull:
-            print(f"Pulling model: {ollama_model}")
-            status = await self.client.pull(ollama_model)
-            assert (
-                status["status"] == "success"
-            ), f"Failed to pull model {self.model} in ollama"
+        return ret
 
     def completion(
         self,
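register_model() now rejects dynamic registration outright, and list_models() instead reports whichever running Ollama models map back to known Llama Stack identifiers. Below is a standalone sketch of that enumeration logic; the two mapping entries are invented placeholders, and only the client.ps() call already used in the hunk above is assumed.

# Standalone sketch of the list_models() enumeration introduced above.
# The mapping entries are illustrative placeholders, not the adapter's real table.
from typing import Dict, List

from ollama import AsyncClient

EXAMPLE_SUPPORTED_MODELS: Dict[str, str] = {
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",  # placeholder entry
    "Llama-Guard-3-8B": "llama-guard3:8b",  # placeholder entry
}


async def list_running_llama_models(client: AsyncClient) -> List[str]:
    # Invert the Llama Stack -> Ollama mapping so that models reported by
    # Ollama can be translated back into Llama Stack identifiers.
    ollama_to_llama = {v: k for k, v in EXAMPLE_SUPPORTED_MODELS.items()}

    ret: List[str] = []
    res = await client.ps()  # ask Ollama which models are currently loaded
    for r in res["models"]:
        if r["model"] not in ollama_to_llama:
            # Unknown to the mapping: skip it, just as the adapter does.
            continue
        ret.append(ollama_to_llama[r["model"]])
    return ret

The adapter itself wraps each identifier in a ModelDef; this sketch returns the bare identifiers to stay self-contained.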
@@ -161,3 +166,10 @@ class OllamaInferenceAdapter(Inference, Models):
             request, stream, self.formatter
         ):
             yield chunk
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
@@ -63,19 +63,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             )
         ]
 
-    async def get_model(self, identifier: str) -> Optional[ModelDef]:
-        model = self.huggingface_repo_to_llama_model_id.get(self.model_id)
-        if model != identifier:
-            return None
-
-        return ModelDef(
-            identifier=model,
-            llama_model=model,
-            metadata={
-                "huggingface_repo": self.model_id,
-            },
-        )
-
     async def shutdown(self) -> None:
         pass
 
@@ -8,6 +8,7 @@ from together import Together
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 
 from .config import TogetherSafetyConfig
 
@@ -19,7 +20,7 @@ TOGETHER_SHIELD_MODEL_MAP = {
 }
 
 
-class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
+class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivate):
     def __init__(self, config: TogetherSafetyConfig) -> None:
         self.config = config
 
@@ -30,8 +31,16 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
         pass
 
     async def register_shield(self, shield: ShieldDef) -> None:
-        if shield.type != ShieldType.llama_guard.value:
-            raise ValueError(f"Unsupported safety shield type: {shield.type}")
+        raise ValueError("Registering dynamic shields is not supported")
+
+    async def list_shields(self) -> List[ShieldDef]:
+        return [
+            ShieldDef(
+                identifier=ShieldType.llama_guard.value,
+                type=ShieldType.llama_guard.value,
+                params={},
+            )
+        ]
 
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
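With dynamic registration disabled, TogetherSafetyImpl now advertises a single, statically defined llama_guard shield through list_shields(). A hedged usage sketch follows, assuming an already-initialized `safety_impl` instance with valid Together credentials; it relies only on the run_shield() signature visible in the hunk above and on the response carrying an optional violation.

# Usage sketch (assumption: `safety_impl` is an already-initialized
# TogetherSafetyImpl and the provider's credentials are configured).
from llama_models.llama3.api.datatypes import UserMessage


async def check_user_input(safety_impl, text: str) -> None:
    shields = await safety_impl.list_shields()
    shield_type = shields[0].identifier  # the statically listed llama_guard shield

    response = await safety_impl.run_shield(
        shield_type=shield_type,
        messages=[UserMessage(content=text)],
    )
    # The run_shield response is expected to expose an optional SafetyViolation.
    if response.violation is not None:
        print(f"Blocked: {response.violation.user_message}")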
@@ -86,7 +95,6 @@ async def get_safety_response(
     if parts[0] == "unsafe":
         return SafetyViolation(
             violation_level=ViolationLevel.ERROR,
             user_message="unsafe",
             metadata={"violation_type": parts[1]},
         )
-