From 75e429f950fcbff97af935c4bf300554afc16ed5 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Feb 2025 14:33:08 -0800
Subject: [PATCH] fix rag test

---
 .../distribution/routers/routing_tables.py   |  2 ++
 .../sentence_transformers.py                 |  3 ---
 llama_stack/strong_typing/auxiliary.py       |  2 +-
 llama_stack/strong_typing/classdef.py        |  2 +-
 llama_stack/strong_typing/deserializer.py    |  2 +-
 llama_stack/strong_typing/inspection.py      |  6 ++++--
 llama_stack/strong_typing/serializer.py      |  2 +-
 .../templates/together/run-with-safety.yaml  |  3 ++-
 llama_stack/templates/together/run.yaml      |  3 ++-
 llama_stack/templates/together/together.py   |  2 +-
 tests/client-sdk/agents/test_agents.py       | 20 +++++++++++++------
 tests/client-sdk/inference/test_embedding.py | 20 +++++++++++++++++++
 12 files changed, 49 insertions(+), 18 deletions(-)
 create mode 100644 tests/client-sdk/inference/test_embedding.py

diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 7543256fc..2cddc3970 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -236,6 +236,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata = {}
         if model_type is None:
             model_type = ModelType.llm
+        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
+            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index f605553ab..6a83836e6 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -45,8 +44,6 @@ class SentenceTransformersInferenceImpl(
         pass

     async def register_model(self, model: Model) -> None:
-        if "embedding_dimension" not in model.metadata and model.model_type == ModelType.embedding:
-            raise ValueError("Embedding model must have an embedding dimension in its metadata")
         _ = self._load_sentence_transformer_model(model.provider_resource_id)
         return model
diff --git a/llama_stack/strong_typing/auxiliary.py b/llama_stack/strong_typing/auxiliary.py
index fd183da18..cf19d6083 100644
--- a/llama_stack/strong_typing/auxiliary.py
+++ b/llama_stack/strong_typing/auxiliary.py
@@ -77,7 +77,7 @@ def typeannotation(
     """

     def wrap(cls: Type[T]) -> Type[T]:
-        setattr(cls, "__repr__", _compact_dataclass_repr)
+        cls.__repr__ = _compact_dataclass_repr
         if not dataclasses.is_dataclass(cls):
             cls = dataclasses.dataclass(  # type: ignore[call-overload]
                 cls,
diff --git a/llama_stack/strong_typing/classdef.py b/llama_stack/strong_typing/classdef.py
index d2d8688e4..5ead886d4 100644
--- a/llama_stack/strong_typing/classdef.py
+++ b/llama_stack/strong_typing/classdef.py
@@ -203,7 +203,7 @@ def schema_to_type(schema: Schema, *, module: types.ModuleType, class_name: str)
         if type_def.default is not dataclasses.MISSING:
             raise TypeError("disallowed: `default` for top-level type definitions")

-        setattr(type_def.type, "__module__", module.__name__)
+        type_def.type.__module__ = module.__name__
         setattr(module, type_name, type_def.type)

     return node_to_typedef(module, class_name, top_node).type
diff --git a/llama_stack/strong_typing/deserializer.py b/llama_stack/strong_typing/deserializer.py
index 4c4ee9d89..fc0f40f83 100644
--- a/llama_stack/strong_typing/deserializer.py
+++ b/llama_stack/strong_typing/deserializer.py
@@ -325,7 +325,7 @@ class TupleDeserializer(Deserializer[Tuple[Any, ...]]):
                 f"type `{self.container_type}` expects a JSON `array` of length {count} but received length {len(data)}"
             )

-        return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data))
+        return tuple(item_parser.parse(item) for item_parser, item in zip(self.item_parsers, data, strict=False))


 class UnionDeserializer(Deserializer):
diff --git a/llama_stack/strong_typing/inspection.py b/llama_stack/strong_typing/inspection.py
index 69bc15597..8bc313021 100644
--- a/llama_stack/strong_typing/inspection.py
+++ b/llama_stack/strong_typing/inspection.py
@@ -263,8 +263,8 @@ def extend_enum(
     enum_class: Type[enum.Enum] = enum.Enum(extend.__name__, values)  # type: ignore

     # assign the newly created type to the same module where the extending class is defined
-    setattr(enum_class, "__module__", extend.__module__)
-    setattr(enum_class, "__doc__", extend.__doc__)
+    enum_class.__module__ = extend.__module__
+    enum_class.__doc__ = extend.__doc__
     setattr(sys.modules[extend.__module__], extend.__name__, enum_class)

     return enum.unique(enum_class)
@@ -874,6 +874,7 @@ def is_generic_instance(obj: Any, typ: TypeLike) -> bool:
             for tuple_item_type, item in zip(
                 (tuple_item_type for tuple_item_type in typing.get_args(typ)),
                 (item for item in obj),
+                strict=False,
             )
         )
     elif origin_type is Union:
@@ -954,6 +955,7 @@ class RecursiveChecker:
                 for tuple_item_type, item in zip(
                     (tuple_item_type for tuple_item_type in typing.get_args(typ)),
                     (item for item in obj),
+                    strict=False,
                 )
             )
         elif origin_type is Union:
diff --git a/llama_stack/strong_typing/serializer.py b/llama_stack/strong_typing/serializer.py
index 5e93e4c4d..4ca4a4119 100644
--- a/llama_stack/strong_typing/serializer.py
+++ b/llama_stack/strong_typing/serializer.py
@@ -216,7 +216,7 @@ class TypedTupleSerializer(Serializer[tuple]):
         self.item_generators = tuple(_get_serializer(item_type, context) for item_type in item_types)

     def generate(self, obj: tuple) -> List[JsonType]:
-        return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj)]
+        return [item_generator.generate(item) for item_generator, item in zip(self.item_generators, obj, strict=False)]


 class CustomSerializer(Serializer):
diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml
index a922ac249..52ee36e34 100644
--- a/llama_stack/templates/together/run-with-safety.yaml
+++ b/llama_stack/templates/together/run-with-safety.yaml
@@ -144,7 +144,8 @@ models:
   provider_id: together
   provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
   model_type: llm
-- metadata: {}
+- metadata:
+    embedding_dimension: 768
   model_id: togethercomputer/m2-bert-80M-8k-retrieval
   provider_id: together
   provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml
index 5f40e1c97..d9fc5c5fc 100644
--- a/llama_stack/templates/together/run.yaml
+++ b/llama_stack/templates/together/run.yaml
@@ -138,7 +138,8 @@ models:
   provider_id: together
   provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
   model_type: llm
-- metadata: {}
+- metadata:
+    embedding_dimension: 768
   model_id: togethercomputer/m2-bert-80M-8k-retrieval
   provider_id: together
   provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval
diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py
index d347b76a9..4c571db05 100644
--- a/llama_stack/templates/together/together.py
+++ b/llama_stack/templates/together/together.py
@@ -86,7 +86,7 @@ def get_distribution_template() -> DistributionTemplate:
         provider_id="together",
         model_type=ModelType.embedding,
         provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
-        metadata={},
+        metadata={"embedding_dimension": 768},
     )

     return DistributionTemplate(
diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py
index e5c20c3a5..dee57ec95 100644
--- a/tests/client-sdk/agents/test_agents.py
+++ b/tests/client-sdk/agents/test_agents.py
@@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
 from llama_stack_client.types.tool_def_param import Parameter

-from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
-from llama_stack.apis.agents.agents import ToolChoice
+from llama_stack.apis.agents.agents import (
+    AgentConfig as Server__AgentConfig,
+)
+from llama_stack.apis.agents.agents import (
+    ToolChoice,
+)


 class TestClientTool(ClientTool):
@@ -417,11 +421,14 @@ def test_rag_agent(llama_stack_client, agent_config):
         )
         for i, url in enumerate(urls)
     ]
+    embedding_models = [x for x in llama_stack_client.models.list() if x.model_type == "embedding"]
+    embedding_model = embedding_models[0]
     vector_db_id = f"test-vector-db-{uuid4()}"
+
     llama_stack_client.vector_dbs.register(
         vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
+        embedding_model=embedding_model.identifier,
+        embedding_dimension=embedding_model.metadata["embedding_dimension"],
     )
     llama_stack_client.tool_runtime.rag_tool.insert(
         documents=documents,
@@ -444,11 +451,11 @@ def test_rag_agent(llama_stack_client, agent_config):
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
     user_prompts = [
         (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
+            "What are the main changes between the Llama2-7B and Llama3-8B models in how attention is used?",
             "grouped",
         ),
         (
-            "What `tune` command to use for getting access to Llama3-8B-Instruct ?",
+            "Which `tune` command should be used to get access to the Llama3-8B-Instruct model?",
             "download",
         ),
     ]
@@ -458,6 +465,7 @@ def test_rag_agent(llama_stack_client, agent_config):
             messages=[{"role": "user", "content": prompt}],
             session_id=session_id,
         )
+
         logs = [str(log) for log in EventLogger().log(response) if log is not None]
         logs_str = "".join(logs)
         assert "Tool:query_from_memory" in logs_str
diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py
new file mode 100644
index 000000000..54469df5b
--- /dev/null
+++ b/tests/client-sdk/inference/test_embedding.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+
+def test_embedding(llama_stack_client):
+    emb_models = [x for x in llama_stack_client.models.list() if x.model_type == "embedding"]
+    if len(emb_models) == 0:
+        pytest.skip("No embedding models found")
+
+    embedding_response = llama_stack_client.inference.embeddings(
+        model_id=emb_models[0].identifier, contents=["Hello, world!", "This is a test", "Testing embeddings"]
+    )
+    assert embedding_response is not None
+    assert len(embedding_response.embeddings) == 3
+    assert len(embedding_response.embeddings[0]) == emb_models[0].metadata["embedding_dimension"]
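
Reviewer note (not part of the patch): the sketch below is a minimal, illustrative way to exercise the change end to end against a locally running stack. The base URL, model_id, and vector_db_id values are placeholders; the client calls mirror the ones used in the tests above.

    # Minimal sketch, assuming a llama-stack server is reachable locally and
    # the llama_stack_client package is installed. Identifiers are placeholders.
    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5000")  # placeholder URL

    # With this patch, registering an embedding model without an
    # "embedding_dimension" in its metadata is rejected server-side by
    # ModelsRoutingTable.register_model with:
    #   ValueError: Embedding model must have an embedding dimension in its metadata
    client.models.register(
        model_id="m2-bert-demo",  # placeholder id
        provider_id="together",
        provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
        model_type="embedding",
        metadata={"embedding_dimension": 768},  # omit this key to trigger the ValueError
    )

    # Consumers can then read the dimension from model metadata instead of
    # hard-coding it, which is exactly what the updated RAG test does:
    embedding_model = next(m for m in client.models.list() if m.model_type == "embedding")
    client.vector_dbs.register(
        vector_db_id="demo-vector-db",  # placeholder id
        embedding_model=embedding_model.identifier,
        embedding_dimension=embedding_model.metadata["embedding_dimension"],
    )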