Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 13:00:39 +00:00)

Commit 7b0ff5718e: Merge branch 'agents-unify-tools' into agents-unify-tools-2

48 changed files with 1850 additions and 2501 deletions
docs/_static/llama-stack-spec.html (vendored, 2205 lines changed): file diff suppressed because it is too large.

docs/_static/llama-stack-spec.yaml (vendored, 1450 lines changed): file diff suppressed because it is too large.
@@ -1017,14 +1017,14 @@
     "    \"content\": SYSTEM_PROMPT_TEMPLATE.format(subject=subset),\n",
     "}\n",
     "\n",
-    "client.eval_tasks.register(\n",
-    "    eval_task_id=\"meta-reference::mmmu\",\n",
+    "client.benchmarks.register(\n",
+    "    benchmark_id=\"meta-reference::mmmu\",\n",
     "    dataset_id=f\"mmmu-{subset}-{split}\",\n",
     "    scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
     ")\n",
     "\n",
-    "response = client.eval.evaluate_rows(\n",
-    "    task_id=\"meta-reference::mmmu\",\n",
+    "response = client.eval.evaluate_rows_alpha(\n",
+    "    benchmark_id=\"meta-reference::mmmu\",\n",
     "    input_rows=eval_rows,\n",
     "    scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
     "    task_config={\n",

@@ -1196,14 +1196,14 @@
     "    provider_id=\"together\",\n",
     ")\n",
     "\n",
-    "client.eval_tasks.register(\n",
-    "    eval_task_id=\"meta-reference::simpleqa\",\n",
+    "client.benchmarks.register(\n",
+    "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    dataset_id=simpleqa_dataset_id,\n",
     "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
     ")\n",
     "\n",
-    "response = client.eval.evaluate_rows(\n",
-    "    task_id=\"meta-reference::simpleqa\",\n",
+    "response = client.eval.evaluate_rows_alpha(\n",
+    "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
     "    task_config={\n",

@@ -1351,8 +1351,8 @@
     "    \"enable_session_persistence\": False,\n",
     "}\n",
     "\n",
-    "response = client.eval.evaluate_rows(\n",
-    "    task_id=\"meta-reference::simpleqa\",\n",
+    "response = client.eval.evaluate_rows_alpha(\n",
+    "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    input_rows=eval_rows.rows,\n",
    "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
     "    task_config={\n",
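Taken together, these notebook hunks migrate the eval flow from the removed `eval_tasks` / `evaluate_rows` names to the new `benchmarks` / `evaluate_rows_alpha` ones. A condensed sketch of the resulting client calls, assuming a configured `client` and the `subset`, `split`, and `eval_rows` values defined earlier in the notebook; the `task_config` body is abridged:

```python
# Register a benchmark (formerly an "eval task") against an existing dataset
client.benchmarks.register(
    benchmark_id="meta-reference::mmmu",
    dataset_id=f"mmmu-{subset}-{split}",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

# Run the evaluation over pre-fetched rows using the renamed alpha endpoint
response = client.eval.evaluate_rows_alpha(
    benchmark_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    task_config={
        # eval candidate (model + sampling params) goes here; omitted for brevity
    },
)
```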
@@ -150,7 +150,14 @@ def _get_endpoint_functions(
         print(f"Processing {colored(func_name, 'white')}...")

         operation_name = func_name
-        if operation_name.startswith("get_") or operation_name.endswith("/get"):
+
+        if webmethod.method == "GET":
+            prefix = "get"
+        elif webmethod.method == "DELETE":
+            prefix = "delete"
+        elif webmethod.method == "POST":
+            prefix = "post"
+        elif operation_name.startswith("get_") or operation_name.endswith("/get"):
             prefix = "get"
         elif (
             operation_name.startswith("delete_")

@@ -160,13 +167,8 @@ def _get_endpoint_functions(
         ):
             prefix = "delete"
         else:
-            if webmethod.method == "GET":
-                prefix = "get"
-            elif webmethod.method == "DELETE":
-                prefix = "delete"
-            else:
-                # by default everything else is a POST
-                prefix = "post"
+            # by default everything else is a POST
+            prefix = "post"

         yield prefix, operation_name, func_name, func_ref
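Net effect of the two hunks above: the declared HTTP method now takes precedence over the function-name heuristics when choosing an operation prefix. A standalone, illustrative distillation of that rule (the real generator works on `webmethod` objects and yields the prefix alongside the operation name, so this helper is an assumption for clarity only):

```python
def choose_prefix(http_method: str, operation_name: str) -> str:
    """Prefer the declared HTTP method, then fall back to naming conventions."""
    if http_method == "GET":
        return "get"
    if http_method == "DELETE":
        return "delete"
    if http_method == "POST":
        return "post"
    if operation_name.startswith("get_") or operation_name.endswith("/get"):
        return "get"
    if operation_name.startswith("delete_") or operation_name.endswith("/delete"):
        return "delete"
    # by default everything else is a POST
    return "post"


assert choose_prefix("POST", "query_traces") == "post"
assert choose_prefix("", "get_span") == "get"
```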
@@ -122,3 +122,20 @@ response = agent.create_turn(
     session_id=session_id,
 )
 ```
+
+### Unregistering Vector DBs
+
+If you need to clean up and unregister vector databases, you can do so as follows:
+
+```python
+# Unregister a specified vector database
+vector_db_id = "my_vector_db_id"
+print(f"Unregistering vector database: {vector_db_id}")
+client.vector_dbs.unregister(vector_db_id=vector_db_id)
+
+# Unregister all vector databases
+for vector_db_id in client.vector_dbs.list():
+    print(f"Unregistering vector database: {vector_db_id.identifier}")
+    client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
+```
@@ -60,6 +60,11 @@ Features:
 - Disabled dangerous system operations
 - Configurable execution timeouts

+> ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerised environments:
+> - The container requires privileged access (e.g., --privileged).
+> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
+> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.
+
 #### WolframAlpha

 The WolframAlpha tool provides access to computational knowledge through the WolframAlpha API.

@@ -103,7 +108,7 @@ Features:

 MCP tools are special tools that can interact with llama stack over model context protocol. These tools are dynamically discovered from an MCP endpoint and can be used to extend the agent's capabilities.

-Refer to https://github.com/modelcontextprotocol/server for available MCP servers.
+Refer to [https://github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) for available MCP servers.

 ```python
 # Register MCP tools
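The registration snippet is cut off by the diff view above. For orientation only, registering an MCP-backed toolgroup with the client generally looks like the sketch below; the toolgroup id, provider id, and endpoint URI are placeholders and are not taken from this commit:

```python
# Register a toolgroup that is served by an MCP endpoint (placeholder values)
client.toolgroups.register(
    toolgroup_id="mcp::my_mcp_tools",
    provider_id="model-context-protocol",
    mcp_endpoint={"uri": "http://localhost:8000/sse"},
)

# The dynamically discovered tools then show up alongside the built-in ones
client.tools.list_tools(toolgroup_id="mcp::my_mcp_tools")
```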
@@ -191,3 +196,36 @@ all_tools = client.tools.list_tools()

 # List tools in a specific group
 group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 ```
+
+## Simple Example: Using an Agent with the Code-Interpreter Tool
+
+```python
+from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.types.agent_create_params import AgentConfig
+
+# Configure the AI agent with necessary parameters
+agent_config = AgentConfig(
+    name="code-interpreter",
+    description="A code interpreter agent for executing Python code snippets",
+    instructions="""
+    You are a highly reliable, concise, and precise assistant.
+    Always show the generated code, never generate your own code, and never anticipate results.
+    """,
+    model="meta-llama/Llama-3.2-3B-Instruct",
+    toolgroups=["builtin::code_interpreter"],
+    max_infer_iters=5,
+    enable_session_persistence=False,
+)
+
+# Instantiate the AI agent with the given configuration
+agent = Agent(client, agent_config)
+
+# Start a session
+session_id = agent.create_session("tool_session")
+
+# Send a query to the AI agent for code execution
+response = agent.create_turn(
+    messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
+    session_id=session_id,
+)
+```
@@ -130,7 +130,7 @@ llama stack run ./run-with-safety.yaml \
 ### (Optional) Update Model Serving Configuration

 ```{note}
-Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
+Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
 ```

 To serve a new model with `ollama`
@@ -64,23 +64,3 @@ class Benchmarks(Protocol):
         provider_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> None: ...
-
-    @webmethod(route="/eval-tasks", method="GET")
-    async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ...
-
-    @webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
-    async def DEPRECATED_get_eval_task(
-        self,
-        eval_task_id: str,
-    ) -> Optional[Benchmark]: ...
-
-    @webmethod(route="/eval-tasks", method="POST")
-    async def DEPRECATED_register_eval_task(
-        self,
-        eval_task_id: str,
-        dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> None: ...
@@ -39,7 +39,6 @@ EvalCandidate = register_schema(

 @json_schema_type
 class BenchmarkConfig(BaseModel):
-    type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
     scoring_params: Dict[str, ScoringFnParams] = Field(
         description="Map between scoring function id and parameters for each scoring function you want to run",

@@ -84,28 +83,3 @@ class Eval(Protocol):

     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
     async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
-
-    @webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
-    async def DEPRECATED_run_eval(
-        self,
-        task_id: str,
-        task_config: BenchmarkConfig,
-    ) -> Job: ...
-
-    @webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST")
-    async def DEPRECATED_evaluate_rows(
-        self,
-        task_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        task_config: BenchmarkConfig,
-    ) -> EvaluateResponse: ...
-
-    @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET")
-    async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
-
-    @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE")
-    async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ...
-
-    @webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET")
-    async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
@@ -216,7 +216,7 @@ class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
     async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...

-    @webmethod(route="/telemetry/traces", method="GET")
+    @webmethod(route="/telemetry/traces", method="POST")
     async def query_traces(
         self,
         attribute_filters: Optional[List[QueryCondition]] = None,

@@ -231,7 +231,7 @@ class Telemetry(Protocol):
     @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
     async def get_span(self, trace_id: str, span_id: str) -> Span: ...

-    @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
+    @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
     async def get_span_tree(
         self,
         span_id: str,

@@ -239,7 +239,7 @@ class Telemetry(Protocol):
         max_depth: Optional[int] = None,
     ) -> QuerySpanTreeResponse: ...

-    @webmethod(route="/telemetry/spans", method="GET")
+    @webmethod(route="/telemetry/spans", method="POST")
     async def query_spans(
         self,
         attribute_filters: List[QueryCondition],
@@ -411,48 +411,6 @@ class EvalRouter(Eval):
             job_id,
         )

-    async def DEPRECATED_run_eval(
-        self,
-        task_id: str,
-        task_config: BenchmarkConfig,
-    ) -> Job:
-        return await self.run_eval(benchmark_id=task_id, task_config=task_config)
-
-    async def DEPRECATED_evaluate_rows(
-        self,
-        task_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        task_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        return await self.evaluate_rows(
-            benchmark_id=task_id,
-            input_rows=input_rows,
-            scoring_functions=scoring_functions,
-            task_config=task_config,
-        )
-
-    async def DEPRECATED_job_status(
-        self,
-        task_id: str,
-        job_id: str,
-    ) -> Optional[JobStatus]:
-        return await self.job_status(benchmark_id=task_id, job_id=job_id)
-
-    async def DEPRECATED_job_cancel(
-        self,
-        task_id: str,
-        job_id: str,
-    ) -> None:
-        return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
-
-    async def DEPRECATED_job_result(
-        self,
-        task_id: str,
-        job_id: str,
-    ) -> EvaluateResponse:
-        return await self.job_result(benchmark_id=task_id, job_id=job_id)
-

 class ToolRuntimeRouter(ToolRuntime):
     class RagToolImpl(RAGToolRuntime):
@@ -468,35 +468,6 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
         )
         await self.register_object(benchmark)

-    async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse:
-        logger.warning("DEPRECATED: Use /eval/benchmarks instead")
-        return await self.list_benchmarks()
-
-    async def DEPRECATED_get_eval_task(
-        self,
-        eval_task_id: str,
-    ) -> Optional[Benchmark]:
-        logger.warning("DEPRECATED: Use /eval/benchmarks instead")
-        return await self.get_benchmark(eval_task_id)
-
-    async def DEPRECATED_register_eval_task(
-        self,
-        eval_task_id: str,
-        dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> None:
-        logger.warning("DEPRECATED: Use /eval/benchmarks instead")
-        return await self.register_benchmark(
-            benchmark_id=eval_task_id,
-            dataset_id=dataset_id,
-            scoring_functions=scoring_functions,
-            metadata=metadata,
-            provider_benchmark_id=provider_benchmark_id,
-        )
-

 class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
@@ -481,6 +481,8 @@ def main():
 def extract_path_params(route: str) -> List[str]:
     segments = route.split("/")
     params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
+    # to handle path params like {param:path}
+    params = [param.split(":")[0] for param in params]
     return params
@@ -234,45 +234,3 @@ class MetaReferenceEvalImpl(
             raise ValueError(f"Job is not completed, Status: {status.value}")

         return self.jobs[job_id]
-
-    async def DEPRECATED_run_eval(
-        self,
-        task_id: str,
-        task_config: BenchmarkConfig,
-    ) -> Job:
-        return await self.run_eval(benchmark_id=task_id, task_config=task_config)
-
-    async def DEPRECATED_evaluate_rows(
-        self,
-        task_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        task_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        return await self.evaluate_rows(
-            benchmark_id=task_id,
-            input_rows=input_rows,
-            scoring_functions=scoring_functions,
-            task_config=task_config,
-        )
-
-    async def DEPRECATED_job_status(
-        self,
-        task_id: str,
-        job_id: str,
-    ) -> Optional[JobStatus]:
-        return await self.job_status(benchmark_id=task_id, job_id=job_id)
-
-    async def DEPRECATED_job_cancel(
-        self,
-        task_id: str,
-        job_id: str,
-    ) -> None:
-        return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
-
-    async def DEPRECATED_job_result(
-        self,
-        task_id: str,
-        job_id: str,
-    ) -> EvaluateResponse:
-        return await self.job_result(benchmark_id=task_id, job_id=job_id)
@@ -46,7 +46,7 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
 )
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,

@@ -116,7 +116,7 @@ class MetaReferenceInferenceImpl(

         self.model_registry_helper = ModelRegistryHelper(
             [
-                build_hf_repo_model_alias(
+                build_hf_repo_model_entry(
                     llama_model.descriptor(),
                     llama_model.core_model_id.value,
                 )
@@ -43,12 +43,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-from .models import MODEL_ALIASES
+from .models import MODEL_ENTRIES


 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self._config = config

         self._client = create_bedrock_client(config)
@@ -6,19 +6,19 @@

 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )

-MODEL_ALIASES = [
-    build_hf_repo_model_alias(
+MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
         "meta.llama3-1-8b-instruct-v1:0",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta.llama3-1-70b-instruct-v1:0",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta.llama3-1-405b-instruct-v1:0",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
@@ -41,14 +41,14 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import CerebrasImplConfig
-from .models import model_aliases
+from .models import model_entries


 class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: CerebrasImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self,
-            model_aliases=model_aliases,
+            model_entries=model_entries,
         )
         self.config = config
@@ -6,15 +6,15 @@

 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )

-model_aliases = [
-    build_hf_repo_model_alias(
+model_entries = [
+    build_hf_repo_model_entry(
         "llama3.1-8b",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama-3.3-70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
@@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,

@@ -38,12 +38,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import DatabricksImplConfig

-model_aliases = [
-    build_hf_repo_model_alias(
+model_entries = [
+    build_hf_repo_model_entry(
         "databricks-meta-llama-3-1-70b-instruct",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "databricks-meta-llama-3-1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),

@@ -52,7 +52,7 @@ model_aliases = [

 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
+        ModelRegistryHelper.__init__(self, model_entries=model_entries)
         self.config = config

     async def initialize(self) -> None:
@@ -47,12 +47,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import FireworksImplConfig
-from .models import MODEL_ALIASES
+from .models import MODEL_ENTRIES


 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config

     async def initialize(self) -> None:
@@ -6,47 +6,47 @@

 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )

-MODEL_ALIASES = [
-    build_hf_repo_model_alias(
+MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p1-8b-instruct",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p1-70b-instruct",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p2-1b-instruct",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p2-3b-instruct",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p3-70b-instruct",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-guard-3-8b",
         CoreModelId.llama_guard_3_8b.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
@@ -31,8 +31,8 @@ from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_alias,
-    build_model_alias,
+    build_hf_repo_model_entry,
+    build_model_entry,
 )

 from .groq_utils import (

@@ -41,20 +41,20 @@ from .groq_utils import (
     convert_chat_completion_response_stream,
 )

-_MODEL_ALIASES = [
-    build_hf_repo_model_alias(
+_MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
         "llama3-8b-8192",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama-3.1-8b-instant",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3-70b-8192",
         CoreModelId.llama3_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama-3.3-70b-versatile",
         CoreModelId.llama3_3_70b_instruct.value,
     ),

@@ -62,7 +62,7 @@ _MODEL_ALIASES = [
     # Preview models aren't recommended for production use, but we include this one
     # to pass the test fixture
     # TODO(aidand): Replace this with a stable model once Groq supports it
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama-3.2-3b-preview",
         CoreModelId.llama3_2_3b_instruct.value,
     ),

@@ -73,7 +73,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
     _config: GroqConfig

     def __init__(self, config: GroqConfig):
-        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
+        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
         self._config = config

     def completion(
@@ -6,43 +6,43 @@

 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )

-_MODEL_ALIASES = [
-    build_hf_repo_model_alias(
+_MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
         "meta/llama3-8b-instruct",
         CoreModelId.llama3_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama3-70b-instruct",
         CoreModelId.llama3_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.1-8b-instruct",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.1-70b-instruct",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.2-1b-instruct",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.2-3b-instruct",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.2-11b-vision-instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta/llama-3.2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
@@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

 from . import NVIDIAConfig
-from .models import _MODEL_ALIASES
+from .models import _MODEL_ENTRIES
 from .openai_utils import (
     convert_chat_completion_request,
     convert_completion_request,

@@ -50,7 +50,7 @@ logger = logging.getLogger(__name__)
 class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
-        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
+        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)

         logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import os
 from typing import Any, Dict

 from pydantic import BaseModel

@@ -12,7 +13,7 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(BaseModel):
-    url: str = DEFAULT_OLLAMA_URL
+    url: str = os.getenv("OLLAMA_URL", DEFAULT_OLLAMA_URL)

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:http://localhost:11434}", **kwargs) -> Dict[str, Any]:
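One practical note on this change: the default is read with `os.getenv` when the config class is defined, so `OLLAMA_URL` must be exported before the module is imported (in practice, before the stack starts). A small illustrative sketch; the hostname is a placeholder:

```python
import os

# Export OLLAMA_URL before importing the config so the new default picks it up
os.environ["OLLAMA_URL"] = "http://ollama.internal:11434"

from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig

print(OllamaImplConfig().url)  # -> http://ollama.internal:11434
```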
@@ -35,8 +35,8 @@ from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_alias,
-    build_model_alias,
+    build_hf_repo_model_entry,
+    build_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,

@@ -58,74 +58,74 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 log = logging.getLogger(__name__)

-model_aliases = [
-    build_hf_repo_model_alias(
+model_entries = [
+    build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.1:8b",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.1:70b-instruct-fp16",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.1:70b",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.1:405b-instruct-fp16",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.1:405b",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.2:1b-instruct-fp16",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.2:1b",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.2:3b-instruct-fp16",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.2:3b",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.2-vision:11b-instruct-fp16",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.2-vision:latest",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.2-vision:90b-instruct-fp16",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
-    build_model_alias(
+    build_model_entry(
         "llama3.2-vision:90b",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
     # The Llama Guard models don't have their full fp16 versions
     # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama-guard3:8b",
         CoreModelId.llama_guard_3_8b.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "llama-guard3:1b",
         CoreModelId.llama_guard_3_1b.value,
     ),

@@ -134,7 +134,7 @@ model_aliases = [

 class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
-        self.register_helper = ModelRegistryHelper(model_aliases)
+        self.register_helper = ModelRegistryHelper(model_entries)
         self.url = url

     @property
@@ -6,43 +6,43 @@

 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )

-MODEL_ALIASES = [
-    build_hf_repo_model_alias(
+MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
         "Meta-Llama-3.1-8B-Instruct",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Meta-Llama-3.1-70B-Instruct",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Meta-Llama-3.1-405B-Instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Meta-Llama-3.2-1B-Instruct",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Meta-Llama-3.2-3B-Instruct",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Meta-Llama-3.3-70B-Instruct",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Llama-3.2-11B-Vision-Instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Llama-3.2-90B-Vision-Instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "Meta-Llama-Guard-3-8B",
         CoreModelId.llama_guard_3_8b.value,
     ),
@@ -31,12 +31,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import SambaNovaImplConfig
-from .models import MODEL_ALIASES
+from .models import MODEL_ENTRIES


 class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: SambaNovaImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
+        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
         self.config = config

     async def initialize(self) -> None:
@@ -32,7 +32,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,

@@ -53,9 +53,9 @@ from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 log = logging.getLogger(__name__)


-def build_hf_repo_model_aliases():
+def build_hf_repo_model_entries():
     return [
-        build_hf_repo_model_alias(
+        build_hf_repo_model_entry(
             model.huggingface_repo,
             model.descriptor(),
         )

@@ -70,7 +70,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     model_id: str

     def __init__(self) -> None:
-        self.register_helper = ModelRegistryHelper(build_hf_repo_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
         }
@@ -6,43 +6,43 @@

 from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )

-MODEL_ALIASES = [
-    build_hf_repo_model_alias(
+MODEL_ENTRIES = [
+    build_hf_repo_model_entry(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Llama-3.2-3B-Instruct-Turbo",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Llama-3.3-70B-Instruct-Turbo",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Meta-Llama-Guard-3-8B",
         CoreModelId.llama_guard_3_8b.value,
     ),
-    build_hf_repo_model_alias(
+    build_hf_repo_model_entry(
         "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
@@ -46,12 +46,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )

 from .config import TogetherImplConfig
-from .models import MODEL_ALIASES
+from .models import MODEL_ENTRIES


 class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: TogetherImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config

     async def initialize(self) -> None:
@@ -38,7 +38,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
-    build_hf_repo_model_alias,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionResponse,

@@ -62,9 +62,9 @@ from .config import VLLMInferenceAdapterConfig
 log = logging.getLogger(__name__)


-def build_hf_repo_model_aliases():
+def build_hf_repo_model_entries():
     return [
-        build_hf_repo_model_alias(
+        build_hf_repo_model_entry(
             model.huggingface_repo,
             model.descriptor(),
         )

@@ -204,7 +204,7 @@ async def _process_vllm_chat_completion_stream_response(

 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        self.register_helper = ModelRegistryHelper(build_hf_repo_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
         self.client = None
@@ -104,3 +104,6 @@ pytest llama_stack/providers/tests/ --config=ci_test_config.yaml
 Currently, we support test config on inference, agents and memory api tests.

 Example format of test config can be found in ci_test_config.yaml.
+
+## Test Data
+We encourage providers to use our test data for internal development testing, so to make it easier and consistent with the tests we provide. Each test case may define its own data format, and please refer to our test source code to get details on how these fields are used in the test.
@@ -6,7 +6,7 @@


 import pytest
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, TypeAdapter, ValidationError

 from llama_stack.apis.common.content_types import ToolCallParseStatus
 from llama_stack.apis.inference import (

@@ -17,6 +17,7 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
     LogProbConfig,
+    Message,
     SystemMessage,
     ToolChoice,
     UserMessage,

@@ -30,6 +31,7 @@ from llama_stack.models.llama.datatypes import (
     ToolParamDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.tests.test_cases.test_case import TestCase

 from .utils import group_chunks

@@ -178,8 +180,9 @@ class TestInference:
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"

+    @pytest.mark.parametrize("test_case", ["completion-01"])
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion_structured_output(self, inference_model, inference_stack):
+    async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack

         class Output(BaseModel):

@@ -187,7 +190,9 @@ class TestInference:
             year_born: str
             year_retired: str

-        user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
+        tc = TestCase(test_case)
+
+        user_input = tc["user_input"]
         response = await inference_impl.completion(
             model_id=inference_model,
             content=user_input,

@@ -203,9 +208,10 @@ class TestInference:
         assert isinstance(response.content, str)

         answer = Output.model_validate_json(response.content)
-        assert answer.name == "Michael Jordan"
-        assert answer.year_born == "1963"
-        assert answer.year_retired == "2003"
+        expected = tc["expected"]
+        assert answer.name == expected["name"]
+        assert answer.year_born == expected["year_born"]
+        assert answer.year_retired == expected["year_retired"]

     @pytest.mark.asyncio(loop_scope="session")
     async def test_chat_completion_non_streaming(

@@ -224,8 +230,9 @@ class TestInference:
         assert isinstance(response.completion_message.content, str)
         assert len(response.completion_message.content) > 0

+    @pytest.mark.parametrize("test_case", ["chat_completion-01"])
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(self, inference_model, inference_stack, common_params):
+    async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
         inference_impl, _ = inference_stack

         class AnswerFormat(BaseModel):

@@ -234,20 +241,12 @@ class TestInference:
             year_of_birth: int
             num_seasons_in_nba: int

+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
+
         response = await inference_impl.chat_completion(
             model_id=inference_model,
-            messages=[
-                # we include context about Michael Jordan in the prompt so that the test is
-                # focused on the funtionality of the model and not on the information embedded
-                # in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
-                SystemMessage(
-                    content=(
-                        "You are a helpful assistant.\n\n"
-                        "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
-                    )
-                ),
-                UserMessage(content="Please give me information about Michael Jordan."),
-            ],
+            messages=messages,
             stream=False,
             response_format=JsonSchemaResponseFormat(
                 json_schema=AnswerFormat.model_json_schema(),

@@ -260,10 +259,11 @@ class TestInference:
         assert isinstance(response.completion_message.content, str)

         answer = AnswerFormat.model_validate_json(response.completion_message.content)
-        assert answer.first_name == "Michael"
-        assert answer.last_name == "Jordan"
-        assert answer.year_of_birth == 1963
-        assert answer.num_seasons_in_nba == 15
+        expected = tc["expected"]
+        assert answer.first_name == expected["first_name"]
+        assert answer.last_name == expected["last_name"]
+        assert answer.year_of_birth == expected["year_of_birth"]
+        assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]

         response = await inference_impl.chat_completion(
             model_id=inference_model,
llama_stack/providers/tests/test_cases/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
llama_stack/providers/tests/test_cases/chat_completion.json (new file, 24 lines)
@@ -0,0 +1,24 @@
+{
+    "01": {
+        "name": "structured output",
+        "data": {
+            "notes": "We include context about Michael Jordan in the prompt so that the test is focused on the funtionality of the model and not on the information embedded in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
+                },
+                {
+                    "role": "user",
+                    "content": "Please give me information about Michael Jordan."
+                }
+            ],
+            "expected": {
+                "first_name": "Michael",
+                "last_name": "Jordan",
+                "year_of_birth": 1963,
+                "num_seasons_in_nba": 15
+            }
+        }
+    }
+}
llama_stack/providers/tests/test_cases/completion.json (new file, 13 lines)
@@ -0,0 +1,13 @@
+{
+    "01": {
+        "name": "structured output",
+        "data": {
+            "user_input": "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.",
+            "expected": {
+                "name": "Michael Jordan",
+                "year_born": "1963",
+                "year_retired": "2003"
+            }
+        }
+    }
+}
llama_stack/providers/tests/test_cases/test_case.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import pathlib
+
+
+class TestCase:
+    _apis = ["chat_completion", "completion"]
+    _jsonblob = {}
+
+    def __init__(self, name):
+        # loading all test cases
+        if self._jsonblob == {}:
+            for api in self._apis:
+                with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
+                    TestCase._jsonblob.update({f"{api}-{k}": v for k, v in json.load(f).items()})
+
+        # loading this test case
+        tc = self._jsonblob.get(name)
+        if tc is None:
+            raise ValueError(f"Test case {name} not found")
+
+        # these are the only fields we need
+        self.name = tc.get("name")
+        self.data = tc.get("data")
+
+    def __getitem__(self, key):
+        return self.data[key]
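A quick usage sketch, mirroring how the updated tests consume these fixtures (the key "completion-01" refers to the "01" entry in completion.json above):

```python
from llama_stack.providers.tests.test_cases.test_case import TestCase

tc = TestCase("completion-01")           # "<api>-<key>" lookup across the bundled JSON files
print(tc.name)                           # "structured output"
print(tc["user_input"])                  # prompt text used by the test
print(tc["expected"]["year_retired"])    # "2003"
```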
@@ -18,7 +18,7 @@ from llama_stack.providers.utils.inference import (

 # TODO: this class is more confusing than useful right now. We need to make it
 # more closer to the Model class.
-class ModelAlias(BaseModel):
+class ProviderModelEntry(BaseModel):
     provider_model_id: str
     aliases: List[str] = Field(default_factory=list)
     llama_model: Optional[str] = None

@@ -32,8 +32,8 @@ def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
     return None


-def build_hf_repo_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
-    return ModelAlias(
+def build_hf_repo_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
+    return ProviderModelEntry(
         provider_model_id=provider_model_id,
         aliases=[
             get_huggingface_repo(model_descriptor),

@@ -42,8 +42,8 @@ def build_hf_repo_model_alias(provider_model_id: str, model_descriptor: str) ->
     )


-def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
-    return ModelAlias(
+def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
+    return ProviderModelEntry(
         provider_model_id=provider_model_id,
         aliases=[],
         llama_model=model_descriptor,

@@ -51,10 +51,10 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:


 class ModelRegistryHelper(ModelsProtocolPrivate):
-    def __init__(self, model_aliases: List[ModelAlias]):
+    def __init__(self, model_entries: List[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
-        for alias_obj in model_aliases:
+        for alias_obj in model_entries:
             for alias in alias_obj.aliases:
                 self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
                 # also add a mapping from provider model id to itself for easy lookup
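This file is the core of the rename: `ModelAlias` becomes `ProviderModelEntry`, and the builder helpers and `ModelRegistryHelper` constructor follow suit. A minimal sketch of how a provider now declares its models with the renamed helpers; the provider model ids below are placeholders, and the signatures are taken from the hunks above:

```python
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
    build_model_entry,
)

MODEL_ENTRIES = [
    # maps the provider's model id to a canonical Llama SKU and also registers
    # the matching HuggingFace repo as an alias
    build_hf_repo_model_entry("my-provider/llama-3.1-8b", CoreModelId.llama3_1_8b_instruct.value),
    # same mapping, but without the HuggingFace-repo alias
    build_model_entry("my-provider/llama-3.1-8b-fast", CoreModelId.llama3_1_8b_instruct.value),
]

helper = ModelRegistryHelper(model_entries=MODEL_ENTRIES)
```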
@@ -10,7 +10,7 @@ from llama_stack.apis.models import ModelInput
 from llama_stack.distribution.datatypes import Provider, ToolGroupInput
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
-from llama_stack.providers.remote.inference.bedrock.models import MODEL_ALIASES
+from llama_stack.providers.remote.inference.bedrock.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -47,7 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="bedrock",
         )
-        for m in MODEL_ALIASES
+        for m in MODEL_ENTRIES
     ]
     default_tool_groups = [
         ToolGroupInput(
@@ -14,7 +14,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
-from llama_stack.providers.remote.inference.cerebras.models import model_aliases
+from llama_stack.providers.remote.inference.cerebras.models import model_entries
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="cerebras",
         )
-        for m in model_aliases
+        for m in model_entries
     ]
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
@@ -19,7 +19,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
-from llama_stack.providers.remote.inference.fireworks.models import MODEL_ALIASES
+from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -67,7 +67,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="fireworks",
         )
-        for m in MODEL_ALIASES
+        for m in MODEL_ENTRIES
     ]
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
@@ -9,7 +9,7 @@ from pathlib import Path
 from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
-from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ALIASES
+from llama_stack.providers.remote.inference.nvidia.models import _MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -45,7 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="nvidia",
         )
-        for m in _MODEL_ALIASES
+        for m in _MODEL_ENTRIES
     ]
     default_tool_groups = [
         ToolGroupInput(
@@ -119,7 +119,7 @@ llama stack run ./run-with-safety.yaml \
 ### (Optional) Update Model Serving Configuration

 ```{note}
-Please check the [model_aliases](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
+Please check the [model_entries](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py#L45) for the supported Ollama models.
 ```

 To serve a new model with `ollama`
@@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
-from llama_stack.providers.remote.inference.sambanova.models import MODEL_ALIASES
+from llama_stack.providers.remote.inference.sambanova.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -47,7 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id=name,
         )
-        for m in MODEL_ALIASES
+        for m in MODEL_ENTRIES
     ]

     default_tool_groups = [
@@ -19,7 +19,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
 )
 from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
-from llama_stack.providers.remote.inference.together.models import MODEL_ALIASES
+from llama_stack.providers.remote.inference.together.models import MODEL_ENTRIES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


@@ -65,7 +65,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_model_id=m.provider_model_id,
             provider_id="together",
         )
-        for m in MODEL_ALIASES
+        for m in MODEL_ENTRIES
     ]
     default_tool_groups = [
         ToolGroupInput(
@@ -7,6 +7,8 @@
 import pytest
 from pydantic import BaseModel

+from llama_stack.providers.tests.test_cases.test_case import TestCase
+
 PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::ollama": "json",
     "remote::together": "json",

@@ -120,16 +122,16 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type):
     assert not chunk.logprobs, "Logprobs should be empty"


-def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
-    user_input = """
-    Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003.
-    """
-
+@pytest.mark.parametrize("test_case", ["completion-01"])
+def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
     class AnswerFormat(BaseModel):
         name: str
         year_born: str
         year_retired: str

+    tc = TestCase(test_case)
+
+    user_input = tc["user_input"]
     response = llama_stack_client.inference.completion(
         model_id=text_model_id,
         content=user_input,

@@ -143,9 +145,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
         },
     )
     answer = AnswerFormat.model_validate_json(response.content)
-    assert answer.name == "Michael Jordan"
-    assert answer.year_born == "1963"
-    assert answer.year_retired == "2003"
+    expected = tc["expected"]
+    assert answer.name == expected["name"]
+    assert answer.year_born == expected["year_born"]
+    assert answer.year_retired == expected["year_retired"]


 @pytest.mark.parametrize(

@@ -247,6 +250,7 @@ def test_text_chat_completion_with_tool_calling_and_streaming(
     assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]"


+@pytest.mark.parametrize("test_case", ["chat_completion-01"])
 def test_text_chat_completion_with_tool_choice_required(
     llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type
 ):

@@ -281,25 +285,18 @@ def test_text_chat_completion_with_tool_choice_none(
     assert tool_invocation_content == ""


-def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
+def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case):
     class AnswerFormat(BaseModel):
         first_name: str
         last_name: str
         year_of_birth: int
         num_seasons_in_nba: int

+    tc = TestCase(test_case)
+
     response = llama_stack_client.inference.chat_completion(
         model_id=text_model_id,
-        messages=[
-            {
-                "role": "system",
-                "content": "You are a helpful assistant. Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons.",
-            },
-            {
-                "role": "user",
-                "content": "Please give me information about Michael Jordan.",
-            },
-        ],
+        messages=tc["messages"],
         response_format={
             "type": "json_schema",
             "json_schema": AnswerFormat.model_json_schema(),

@@ -307,10 +304,11 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type):
         stream=False,
     )
     answer = AnswerFormat.model_validate_json(response.completion_message.content)
-    assert answer.first_name == "Michael"
-    assert answer.last_name == "Jordan"
-    assert answer.year_of_birth == 1963
-    assert answer.num_seasons_in_nba == 15
+    expected = tc["expected"]
+    assert answer.first_name == expected["first_name"]
+    assert answer.last_name == expected["last_name"]
+    assert answer.year_of_birth == expected["year_of_birth"]
+    assert answer.num_seasons_in_nba == expected["num_seasons_in_nba"]


 @pytest.mark.parametrize(