Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)
Fix LibraryClient completely correctly; also make tests pass
This commit is contained in: commit 1bcc26ccd1 (parent d4935ca439)
7 changed files with 201 additions and 100 deletions
@@ -13,10 +13,20 @@ import threading
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
+from typing import Any, Generator, get_args, get_origin, Optional, TypeVar
 
+import httpx
+
 import yaml
-from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
+from llama_stack_client import (
+    APIResponse,
+    AsyncAPIResponse,
+    AsyncLlamaStackClient,
+    AsyncStream,
+    LlamaStackClient,
+    NOT_GIVEN,
+    Stream,
+)
 from pydantic import BaseModel, TypeAdapter
 from rich.console import Console
@@ -66,7 +76,7 @@ def stream_across_asyncio_run_boundary(
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
         try:
-            async for item in gen:
+            async for item in await gen:
                 result_queue.put(item)
         except Exception as e:
             print(f"Error in generator {e}")
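This function bridges an async generator running under its own `asyncio.run()` into a synchronous iterator. A minimal sketch of the thread-plus-queue pattern it uses (names are hypothetical, not the repo's code):

```python
import asyncio
import queue
import threading

def stream_sync(async_gen_maker):
    # Run the async generator on its own event loop in a worker thread and
    # hand items back through a queue; a sentinel marks exhaustion.
    _DONE = object()
    q: queue.Queue = queue.Queue()

    def worker():
        async def pump():
            async for item in async_gen_maker():
                q.put(item)
        try:
            asyncio.run(pump())
        finally:
            q.put(_DONE)

    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not _DONE:
        yield item

async def numbers():
    for i in range(3):
        yield i

print(list(stream_sync(numbers)))  # [0, 1, 2]
```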
@@ -112,31 +122,17 @@ def stream_across_asyncio_run_boundary(
                 future.result()
 
 
-def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
+def convert_pydantic_to_json_value(value: Any) -> Any:
     if isinstance(value, Enum):
         return value.value
     elif isinstance(value, list):
-        return [convert_pydantic_to_json_value(item, cast_to) for item in value]
+        return [convert_pydantic_to_json_value(item) for item in value]
     elif isinstance(value, dict):
-        return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
+        return {k: convert_pydantic_to_json_value(v) for k, v in value.items()}
     elif isinstance(value, BaseModel):
-        # This is quite hacky and we should figure out how to use stuff from
-        # generated client-sdk code (using ApiResponse.parse() essentially)
-        value_dict = json.loads(value.model_dump_json())
-
-        origin = get_origin(cast_to)
-        if origin is Union:
-            args = get_args(cast_to)
-            for arg in args:
-                arg_name = arg.__name__.split(".")[-1]
-                value_name = value.__class__.__name__.split(".")[-1]
-                if arg_name == value_name:
-                    return arg(**value_dict)
-
-        # assume we have the correct association between the server-side type and the client-side type
-        return cast_to(**value_dict)
-
-    return value
+        return json.loads(value.model_dump_json())
+    else:
+        return value
 
 
 def convert_to_pydantic(annotation: Any, value: Any) -> Any:
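The rewritten `convert_pydantic_to_json_value` now only lowers values to plain JSON types and leaves casting to the client-side response machinery. A minimal sketch of the new behavior, with a made-up model (not the repo's types):

```python
import json
from enum import Enum
from pydantic import BaseModel

class Color(Enum):
    RED = "red"

class Item(BaseModel):  # hypothetical model standing in for a server-side type
    name: str
    color: Color

def to_json_value(value):
    # Same shape as the patched convert_pydantic_to_json_value.
    if isinstance(value, Enum):
        return value.value
    elif isinstance(value, list):
        return [to_json_value(v) for v in value]
    elif isinstance(value, dict):
        return {k: to_json_value(v) for k, v in value.items()}
    elif isinstance(value, BaseModel):
        return json.loads(value.model_dump_json())
    else:
        return value

print(to_json_value({"items": [Item(name="a", color=Color.RED)]}))
# {'items': [{'name': 'a', 'color': 'red'}]}
```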
@@ -278,16 +274,28 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
-        params = options.params or {}
-        params |= options.json_data or {}
         if stream:
-            return self._call_streaming(options.url, params, cast_to)
+            return self._call_streaming(
+                cast_to=cast_to,
+                options=options,
+                stream_cls=stream_cls,
+            )
         else:
-            return await self._call_non_streaming(options.url, params, cast_to)
+            return await self._call_non_streaming(
+                cast_to=cast_to,
+                options=options,
+            )
 
     async def _call_non_streaming(
-        self, path: str, body: dict = None, cast_to: Any = None
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
     ):
+        path = options.url
+
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
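Both helpers now take keyword-only parameters (note the bare `*`) and derive path and body from the request options, so call sites cannot transpose arguments. A tiny illustration of the keyword-only signature (names hypothetical):

```python
def call(*, cast_to, options):
    # The bare "*" forces every argument to be named at the call site.
    return cast_to, options

call(cast_to=dict, options={"url": "/health"})  # OK
# call(dict, {"url": "/health"})  -> TypeError: takes 0 positional arguments
```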
@@ -295,11 +303,45 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            return convert_pydantic_to_json_value(await func(**body), cast_to)
+            result = await func(**body)
+
+            json_content = json.dumps(convert_pydantic_to_json_value(result))
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=json_content.encode("utf-8"),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+            response = APIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=False,
+                stream_cls=None,
+            )
+            return response.parse()
         finally:
             await end_trace()
 
-    async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
+    async def _call_streaming(
+        self,
+        *,
+        cast_to: Any,
+        options: Any,
+        stream_cls: Any,
+    ):
+        path = options.url
+        body = options.params or {}
+        body |= options.json_data or {}
         await start_trace(path, {"__location__": "library_client"})
         try:
             func = self.endpoint_impls.get(path)
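The non-streaming path now wraps the in-process result in a synthetic `httpx.Response` so that the client SDK's `APIResponse.parse()` performs the cast that the deleted `cast_to` plumbing used to attempt by hand. A self-contained sketch of the same trick (the model and URL are illustrative, not from the repo):

```python
import json
import httpx
from pydantic import BaseModel

class HealthInfo(BaseModel):  # hypothetical cast_to target
    status: str

result = {"status": "OK"}  # what an in-process endpoint returned
mock_response = httpx.Response(
    status_code=httpx.codes.OK,
    content=json.dumps(result).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    request=httpx.Request(method="GET", url="http://localhost/health"),
)
# Generic "parse the HTTP body into cast_to" machinery can now run unchanged:
parsed = HealthInfo(**mock_response.json())
print(parsed.status)  # OK
```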
@@ -307,8 +349,40 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 raise ValueError(f"No endpoint found for {path}")
 
             body = self._convert_body(path, body)
-            async for chunk in await func(**body):
-                yield convert_pydantic_to_json_value(chunk, cast_to)
+
+            async def gen():
+                async for chunk in await func(**body):
+                    data = json.dumps(convert_pydantic_to_json_value(chunk))
+                    sse_event = f"data: {data}\n\n"
+                    yield sse_event.encode("utf-8")
+
+            mock_response = httpx.Response(
+                status_code=httpx.codes.OK,
+                content=gen(),
+                headers={
+                    "Content-Type": "application/json",
+                },
+                request=httpx.Request(
+                    method=options.method,
+                    url=options.url,
+                    params=options.params,
+                    headers=options.headers,
+                    json=options.json_data,
+                ),
+            )
+
+            origin = get_origin(stream_cls)
+            assert origin is Stream
+            args = get_args(stream_cls)
+            stream_cls = AsyncStream[args[0]]
+            response = AsyncAPIResponse(
+                raw=mock_response,
+                client=self,
+                cast_to=cast_to,
+                options=options,
+                stream=True,
+                stream_cls=stream_cls,
+            )
+            return await response.parse()
         finally:
             await end_trace()
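Streaming takes the same route: each chunk is framed as a server-sent event so `AsyncStream` can decode it exactly as it would real wire traffic. A minimal sketch of the framing, mirroring the `data: <json>\n\n` format in the hunk above:

```python
import asyncio
import json

async def sse_encode(chunks):
    # Frame each JSON-serializable chunk as an SSE event.
    async for chunk in chunks:
        yield f"data: {json.dumps(chunk)}\n\n".encode("utf-8")

async def demo():
    async def chunks():
        yield {"delta": "hel"}
        yield {"delta": "lo"}
    async for event in sse_encode(chunks()):
        print(event)  # b'data: {"delta": "hel"}\n\n' ...

asyncio.run(demo())
```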
@@ -6,6 +6,7 @@
 
 import logging
 import os
+import re
 from pathlib import Path
 from typing import Any, Dict
 
@@ -143,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
             if default_val is None:
                 raise EnvVarError(env_var, path)
             else:
-                value = default_val
+                value = default_val if default_val != "null" else None
 
         # expand "~" from the values
         return os.path.expanduser(value)
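With this one-line change, a default spelled `null` resolves to Python `None` rather than the literal string `"null"`. A stand-in sketch of the patched branch (`resolve()` is illustrative, not the repo's function):

```python
import os

def resolve(env_var: str, default_val):
    value = os.environ.get(env_var)
    if value is None:
        if default_val is None:
            raise ValueError(f"{env_var} is not set and has no default")
        value = default_val if default_val != "null" else None
    return value

assert resolve("SOME_UNSET_VAR", "null") is None    # "null" -> None
assert resolve("SOME_UNSET_VAR", "8080") == "8080"  # ordinary default
```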
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
 from contextlib import asynccontextmanager
 from typing import Dict, List, Optional, Protocol, Tuple
 
@@ -54,10 +53,7 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
     """Utility function to parse registry values into RoutableObjectWithProvider objects."""
     all_objects = []
     for value in values:
-        obj = pydantic.parse_obj_as(
-            RoutableObjectWithProvider,
-            json.loads(value),
-        )
+        obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
         all_objects.append(obj)
     return all_objects
@@ -89,14 +85,7 @@ class DiskDistributionRegistry(DistributionRegistry):
         if not json_str:
             return None
 
-        objects_data = json.loads(json_str)
-        # Return only the first object if any exist
-        if objects_data:
-            return pydantic.parse_obj_as(
-                RoutableObjectWithProvider,
-                json.loads(objects_data),
-            )
-        return None
+        return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
 
     async def update(self, obj: RoutableObjectWithProvider) -> None:
         await self.kvstore.set(
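Both registry call sites replace the deprecated `pydantic.parse_obj_as` (plus a redundant `json.loads`) with pydantic v2's `TypeAdapter`, which validates straight from the JSON string; that is why the `import json` above became removable. A sketch with stand-in models:

```python
from typing import Union
from pydantic import BaseModel, TypeAdapter

class Disk(BaseModel):  # stand-ins for the RoutableObjectWithProvider union
    kind: str
    path: str

class Memory(BaseModel):
    kind: str
    size_mb: int

adapter = TypeAdapter(Union[Disk, Memory])
obj = adapter.validate_json('{"kind": "mem", "size_mb": 64}')
assert isinstance(obj, Memory)  # validation picks the matching union member
```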
@@ -8,6 +8,7 @@ import json
 from typing import Dict, List
 from uuid import uuid4
 
+import pytest
 from llama_stack.providers.tests.env import get_env_or_fail
 
 from llama_stack_client.lib.agents.agent import Agent
@@ -77,16 +78,20 @@ class TestCustomTool(CustomTool):
         return -1
 
 
-def get_agent_config_with_available_models_shields(llama_stack_client):
+@pytest.fixture(scope="session")
+def agent_config(llama_stack_client):
     available_models = [
         model.identifier
         for model in llama_stack_client.models.list()
-        if model.identifier.startswith("meta-llama")
+        if model.identifier.startswith("meta-llama") and "405" not in model.identifier
     ]
     model_id = available_models[0]
+    print(f"Using model: {model_id}")
     available_shields = [
         shield.identifier for shield in llama_stack_client.shields.list()
     ]
+    available_shields = available_shields[:1]
+    print(f"Using shield: {available_shields}")
     agent_config = AgentConfig(
         model=model_id,
         instructions="You are a helpful assistant",
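Turning the helper into a session-scoped fixture means model and shield discovery runs once per session; tests that need variations copy the dict by spreading (as the hunks below show), leaving the shared baseline untouched:

```python
# Sketch of the override pattern the updated tests rely on:
base = {"model": "some-model", "instructions": "You are a helpful assistant"}
override = {**base, "tools": [{"type": "code_interpreter"}]}

assert "tools" not in base  # the session-wide baseline is not mutated
```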
@@ -105,8 +110,7 @@ def get_agent_config_with_available_models_shields(llama_stack_client):
     return agent_config
 
 
-def test_agent_simple(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
+def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -142,16 +146,18 @@ def test_agent_simple(llama_stack_client):
     assert "I can't" in logs_str
 
 
-def test_builtin_tool_brave_search(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        }
-    ]
-    print(agent_config)
+def test_builtin_tool_brave_search(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            }
+        ],
+    }
+    print(f"Agent Config: {agent_config}")
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -174,13 +180,15 @@ def test_builtin_tool_brave_search(llama_stack_client):
     assert "No Violation" in logs_str
 
 
-def test_builtin_tool_code_execution(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["tools"] = [
-        {
-            "type": "code_interpreter",
-        }
-    ]
+def test_builtin_tool_code_execution(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "tools": [
+            {
+                "type": "code_interpreter",
+            }
+        ],
+    }
     agent = Agent(llama_stack_client, agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -200,34 +208,36 @@ def test_builtin_tool_code_execution(llama_stack_client):
     assert "Tool:code_interpreter Response" in logs_str
 
 
-def test_custom_tool(llama_stack_client):
-    agent_config = get_agent_config_with_available_models_shields(llama_stack_client)
-    agent_config["model"] = "meta-llama/Llama-3.2-3B-Instruct"
-    agent_config["tools"] = [
-        {
-            "type": "brave_search",
-            "engine": "brave",
-            "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
-        },
-        {
-            "function_name": "get_boiling_point",
-            "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
-            "parameters": {
-                "liquid_name": {
-                    "param_type": "str",
-                    "description": "The name of the liquid",
-                    "required": True,
-                },
-                "celcius": {
-                    "param_type": "boolean",
-                    "description": "Whether to return the boiling point in Celcius",
-                    "required": False,
-                },
-            },
-            "type": "function_call",
-        },
-    ]
-    agent_config["tool_prompt_format"] = "python_list"
+def test_custom_tool(llama_stack_client, agent_config):
+    agent_config = {
+        **agent_config,
+        "model": "meta-llama/Llama-3.2-3B-Instruct",
+        "tools": [
+            {
+                "type": "brave_search",
+                "engine": "brave",
+                "api_key": get_env_or_fail("BRAVE_SEARCH_API_KEY"),
+            },
+            {
+                "function_name": "get_boiling_point",
+                "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)",
+                "parameters": {
+                    "liquid_name": {
+                        "param_type": "str",
+                        "description": "The name of the liquid",
+                        "required": True,
+                    },
+                    "celcius": {
+                        "param_type": "boolean",
+                        "description": "Whether to return the boiling point in Celcius",
+                        "required": False,
+                    },
+                },
+                "type": "function_call",
+            },
+        ],
+        "tool_prompt_format": "python_list",
+    }
+
     agent = Agent(llama_stack_client, agent_config, custom_tools=(TestCustomTool(),))
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -3,13 +3,22 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import os
 
 import pytest
+from llama_stack import LlamaStackAsLibraryClient
 
 from llama_stack.providers.tests.env import get_env_or_fail
 from llama_stack_client import LlamaStackClient
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def llama_stack_client():
-    """Fixture to create a fresh LlamaStackClient instance for each test"""
-    return LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    if os.environ.get("LLAMA_STACK_CONFIG"):
+        client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
+        client.initialize()
+    elif os.environ.get("LLAMA_STACK_BASE_URL"):
+        client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
+    else:
+        raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
+    return client
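With this fixture the same suite can exercise either the in-process library client or a remote server, selected purely by environment. A usage sketch (the config path and URL are illustrative):

```python
# Library mode:
#   LLAMA_STACK_CONFIG=/path/to/run.yaml pytest
# Server mode:
#   LLAMA_STACK_BASE_URL=http://localhost:5000 pytest
import os

os.environ["LLAMA_STACK_BASE_URL"] = "http://localhost:5000"
# The fixture above now takes the LlamaStackClient(base_url=...) branch.
```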
@@ -4,10 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import sys
+import traceback
+import warnings
+
 import pytest
 from llama_stack_client.lib.inference.event_logger import EventLogger
 
 
+def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
+    log = file if hasattr(file, "write") else sys.stderr
+    traceback.print_stack(file=log)
+    log.write(warnings.formatwarning(message, category, filename, lineno, line))
+
+
+warnings.showwarning = warn_with_traceback
+
+
 def test_text_chat_completion(llama_stack_client):
     # non-streaming
     available_models = [
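The `warn_with_traceback` hook added above prints the call stack at the point a warning is raised, which helps attribute warnings such as "coroutine ... was never awaited" to the offending test. It relies only on stdlib behavior; triggering a warning is enough to see the effect:

```python
import sys
import traceback
import warnings

def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    log = file if hasattr(file, "write") else sys.stderr
    traceback.print_stack(file=log)
    log.write(warnings.formatwarning(message, category, filename, lineno, line))

warnings.showwarning = warn_with_traceback
warnings.warn("demo")  # the full call stack now precedes the warning text
```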
@@ -55,11 +68,15 @@ def test_image_chat_completion(llama_stack_client):
             "role": "user",
             "content": [
                 {
-                    "image": {
+                    "type": "image",
+                    "data": {
                         "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
-                    }
+                    },
+                },
+                {
+                    "type": "text",
+                    "text": "Describe what is in this image.",
                 },
-                "Describe what is in this image.",
             ],
         }
         response = llama_stack_client.inference.chat_completion(
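The message body moves to explicitly typed content items, so image and text parts are uniform dicts rather than a mix of a dict and a bare string. Reassembled from the hunk (image URL shortened here):

```python
message = {
    "role": "user",
    "content": [
        {"type": "image", "data": {"uri": "https://example.com/dog.jpg"}},
        {"type": "text", "text": "Describe what is in this image."},
    ],
}
```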
@@ -22,6 +22,7 @@ def test_memory_bank(llama_stack_client):
     llama_stack_client.memory_banks.register(
         memory_bank_id=memory_bank_id,
         params={
+            "memory_bank_type": "vector",
             "embedding_model": "all-MiniLM-L6-v2",
             "chunk_size_in_tokens": 512,
             "overlap_size_in_tokens": 64,
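After the fix, a memory bank registration names its bank type explicitly. The full `params` dict from the hunk, assembled for reference:

```python
params = {
    "memory_bank_type": "vector",
    "embedding_model": "all-MiniLM-L6-v2",
    "chunk_size_in_tokens": 512,
    "overlap_size_in_tokens": 64,
}
```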