Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file, as well as fix and ignore some additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
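For context, a `ruff.toml` of the kind this commit describes might look like the minimal sketch below. The actual file added in this commit is not shown in this excerpt, so the line length and the selected/ignored rules here are illustrative assumptions only (the collapsed lines in the diff suggest a line length around 120).

```toml
# Hypothetical ruff.toml sketch -- NOT the actual file from this commit.
# The diff collapses many wrapped lines, which suggests a line length of
# roughly 120; the rule selections below are illustrative assumptions.
line-length = 120

[lint]
# Enable pycodestyle errors, pyflakes, and import sorting.
select = ["E", "F", "I"]
# Example of ignoring a check, as the commit message describes; E501
# (line-too-long) is commonly ignored when the formatter enforces length.
ignore = ["E501"]
```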
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
@@ -15,9 +15,7 @@ from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
 )
 
 
-def create_bedrock_client(
-    config: BedrockBaseConfig, service_name: str = "bedrock-runtime"
-) -> BaseClient:
+def create_bedrock_client(config: BedrockBaseConfig, service_name: str = "bedrock-runtime") -> BaseClient:
     """Creates a boto3 client for Bedrock services with the given configuration.
 
     Args:

@@ -28,8 +28,7 @@ class BedrockBaseConfig(BaseModel):
     )
     profile_name: Optional[str] = Field(
         default=None,
-        description="The profile name that contains credentials to use."
-        "Default use environment variable: AWS_PROFILE",
+        description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
     )
     total_max_attempts: Optional[int] = Field(
         default=None,

@@ -68,9 +68,7 @@ class RefreshableBotoSession:
 
         # if sts_arn is given, get credential by assuming the given role
         if self.sts_arn:
-            sts_client = session.client(
-                service_name="sts", region_name=self.region_name
-            )
+            sts_client = session.client(service_name="sts", region_name=self.region_name)
             response = sts_client.assume_role(
                 RoleArn=self.sts_arn,
                 RoleSessionName=self.session_name,

@@ -68,9 +68,7 @@ def validate_dataset_schema(
     expected_schemas: List[Dict[str, Any]],
 ):
     if dataset_schema not in expected_schemas:
-        raise ValueError(
-            f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
-        )
+        raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
 
 
 def validate_row_schema(

@@ -81,6 +79,4 @@ def validate_row_schema(
     if all(key in input_row for key in schema):
         return
 
-    raise ValueError(
-        f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
-    )
+    raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")

@@ -27,13 +27,10 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family
-            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
+            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
             or is_supported_safety_model(m)
         )
     ]
 
 
-ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
-    m.huggingface_repo: m.descriptor() for m in all_registered_models()
-}
+ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {m.huggingface_repo: m.descriptor() for m in all_registered_models()}

@@ -28,9 +28,7 @@ class SentenceTransformerEmbeddingMixin:
         contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(
-            model.provider_resource_id
-        )
+        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
         embeddings = embedding_model.encode(contents)
         return EmbeddingsResponse(embeddings=embeddings)
 

@@ -36,9 +36,7 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
     )
 
 
-def build_model_alias_with_just_provider_model_id(
-    provider_model_id: str, model_descriptor: str
-) -> ModelAlias:
+def build_model_alias_with_just_provider_model_id(provider_model_id: str, model_descriptor: str) -> ModelAlias:
     return ModelAlias(
         provider_model_id=provider_model_id,
         aliases=[],

@@ -54,16 +52,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             for alias in alias_obj.aliases:
                 self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
             # also add a mapping from provider model id to itself for easy lookup
-            self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
-                alias_obj.provider_model_id
-            )
+            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
             # ensure we can go from llama model to provider model id
-            self.alias_to_provider_id_map[alias_obj.llama_model] = (
-                alias_obj.provider_model_id
-            )
-            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
-                alias_obj.llama_model
-            )
+            self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
+            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
 
     def get_provider_model_id(self, identifier: str) -> str:
         if identifier in self.alias_to_provider_id_map:

@@ -82,9 +74,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
             provider_resource_id = model.provider_resource_id
         else:
-            provider_resource_id = self.get_provider_model_id(
-                model.provider_resource_id
-            )
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:

@@ -100,18 +90,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if (
-                    model.metadata["llama_model"]
-                    not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
-                ):
+                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
                         f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
-                        model.metadata["llama_model"]
-                    ]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
                 )
 
         return model

@@ -135,9 +135,7 @@ def convert_openai_completion_logprobs(
     return None
 
 
-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):

@@ -148,9 +146,7 @@ def convert_openai_completion_logprobs_stream(
     return None
 
 
-def process_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
-) -> CompletionResponse:
+def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
     choice = response.choices[0]
     # drop suffix <eot_id> if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):

@@ -341,17 +337,13 @@ async def process_chat_completion_stream_response(
     )
 
 
-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:

@@ -119,9 +119,7 @@ async def interleaved_content_convert_to_raw(
             if image.url.uri.startswith("data"):
                 match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                 if not match:
-                    raise ValueError(
-                        f"Invalid data URL format, {image.url.uri[:40]}..."
-                    )
+                    raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
                 _, image_data = match.groups()
                 data = base64.b64decode(image_data)
             elif image.url.uri.startswith("file://"):

@@ -201,19 +199,13 @@ async def convert_image_content_to_url(
 
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode(
-            "utf-8"
-        )
+        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
     else:
         return base64.b64encode(content).decode("utf-8")
 
 
-async def completion_request_to_prompt(
-    request: CompletionRequest, formatter: ChatFormat
-) -> str:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
     model_input = formatter.encode_content(request.content)

@@ -223,9 +215,7 @@ async def completion_request_to_prompt(
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
     model_input = formatter.encode_content(request.content)

@@ -288,8 +278,7 @@ def chat_completion_request_to_messages(
         return request.messages
 
     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2
-        and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)

@@ -327,9 +316,7 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
 
     messages = []
 

@@ -361,9 +348,7 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join(
-                [_process(c) for c in existing_system_message.content]
-            )
+            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
 
     messages.append(SystemMessage(content=sys_content))
 

@@ -397,9 +382,7 @@ def augment_messages_for_tools_llama_3_2(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
 
     messages = []
     sys_content = ""

@@ -422,9 +405,7 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
 
         tool_gen = PythonListCustomToolGenerator()
         tool_template = tool_gen.gen(custom_tools)

@@ -433,9 +414,7 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"
 
     if existing_system_message:
-        sys_content += interleaved_content_as_str(
-            existing_system_message.content, sep="\n"
-        )
+        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
 
     messages.append(SystemMessage(content=sys_content))
 

@@ -10,9 +10,7 @@ from typing import List, Optional, Protocol
 
 class KVStore(Protocol):
     # TODO: make the value type bytes instead of str
-    async def set(
-        self, key: str, value: str, expiration: Optional[datetime] = None
-    ) -> None: ...
+    async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: ...
 
     async def get(self, key: str) -> Optional[str]: ...
 

@@ -54,16 +54,11 @@ class SqliteKVStoreConfig(CommonConfig):
     )
 
     @classmethod
-    def sample_run_config(
-        cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
-    ):
+    def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"):
         return {
             "type": "sqlite",
             "namespace": None,
-            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/"
-            + __distro_dir__
-            + "}/"
-            + db_name,
+            "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name,
         }
 
 

@@ -28,11 +28,7 @@ class InmemoryKVStoreImpl(KVStore):
         self._store[key] = value
 
     async def range(self, start_key: str, end_key: str) -> List[str]:
-        return [
-            self._store[key]
-            for key in self._store.keys()
-            if key >= start_key and key < end_key
-        ]
+        return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
 
 
 async def kvstore_impl(config: KVStoreConfig) -> KVStore:
@@ -46,7 +46,6 @@ class PostgresKVStoreImpl(KVStore):
                 """
             )
         except Exception as e:
-
             log.exception("Could not connect to PostgreSQL database server")
             raise RuntimeError("Could not connect to PostgreSQL database server") from e
 
@@ -55,9 +54,7 @@ class PostgresKVStoreImpl(KVStore):
             return key
         return f"{self.config.namespace}:{key}"
 
-    async def set(
-        self, key: str, value: str, expiration: Optional[datetime] = None
-    ) -> None:
+    async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
         key = self._namespaced_key(key)
         self.cursor.execute(
             f"""

@@ -25,9 +25,7 @@ class RedisKVStoreImpl(KVStore):
             return key
         return f"{self.config.namespace}:{key}"
 
-    async def set(
-        self, key: str, value: str, expiration: Optional[datetime] = None
-    ) -> None:
+    async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
         key = self._namespaced_key(key)
         await self.redis.set(key, value)
         if expiration:

@@ -66,9 +64,7 @@ class RedisKVStoreImpl(KVStore):
         if matching_keys:
             values = await self.redis.mget(matching_keys)
             return [
-                value.decode("utf-8") if isinstance(value, bytes) else value
-                for value in values
-                if value is not None
+                value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
             ]
 
         return []

@@ -34,9 +34,7 @@ class SqliteKVStoreImpl(KVStore):
             )
             await db.commit()
 
-    async def set(
-        self, key: str, value: str, expiration: Optional[datetime] = None
-    ) -> None:
+    async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
         async with aiosqlite.connect(self.db_path) as db:
             await db.execute(
                 f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",

@@ -46,9 +44,7 @@ class SqliteKVStoreImpl(KVStore):
 
     async def get(self, key: str) -> Optional[str]:
         async with aiosqlite.connect(self.db_path) as db:
-            async with db.execute(
-                f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
-            ) as cursor:
+            async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
                 row = await cursor.fetchone()
                 if row is None:
                     return None

@@ -141,9 +141,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
     return interleaved_content_as_str(doc.content)
 
 
-def make_overlapped_chunks(
-    document_id: str, text: str, window_len: int, overlap_len: int
-) -> List[Chunk]:
+def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> List[Chunk]:
     tokenizer = Tokenizer.get_instance()
     tokens = tokenizer.encode(text, bos=False, eos=False)
 

@@ -171,9 +169,7 @@ class EmbeddingIndex(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    async def query(
-        self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryChunksResponse:
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError()
 
     @abstractmethod

@@ -209,8 +205,6 @@ class VectorDBWithIndex:
         score_threshold = params.get("score_threshold", 0.0)
 
         query_str = interleaved_content_as_str(query)
-        embeddings_response = await self.inference_api.embeddings(
-            self.vector_db.embedding_model, [query_str]
-        )
+        embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str])
         query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)

@@ -23,9 +23,7 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
 
 def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
     return {
-        "average": sum(
-            result["score"] for result in scoring_results if result["score"] is not None
-        )
+        "average": sum(result["score"] for result in scoring_results if result["score"] is not None)
         / len([_ for _ in scoring_results if _["score"] is not None]),
     }
 

@@ -70,9 +70,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
 
     def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
         if scoring_fn.identifier in self.supported_fn_defs_registry:
-            raise ValueError(
-                f"Scoring function def with identifier {scoring_fn.identifier} already exists."
-            )
+            raise ValueError(f"Scoring function def with identifier {scoring_fn.identifier} already exists.")
         self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
 
     @abstractmethod

@@ -98,11 +96,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
             params.aggregation_functions = scoring_params.aggregation_functions
 
         aggregation_functions = []
-        if (
-            params
-            and hasattr(params, "aggregation_functions")
-            and params.aggregation_functions
-        ):
+        if params and hasattr(params, "aggregation_functions") and params.aggregation_functions:
             aggregation_functions.extend(params.aggregation_functions)
         return aggregate_metrics(scoring_results, aggregation_functions)
 

@@ -112,7 +106,4 @@ class RegisteredBaseScoringFn(BaseScoringFn):
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> List[ScoringResultRow]:
-        return [
-            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
-            for input_row in input_rows
-        ]
+        return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]

@@ -64,8 +64,7 @@ class TelemetryDatasetMixin:
 
         for span in spans_by_id_resp.data.values():
             if span.attributes and all(
-                attr in span.attributes and span.attributes[attr] is not None
-                for attr in attributes_to_return
+                attr in span.attributes and span.attributes[attr] is not None for attr in attributes_to_return
             ):
                 spans.append(
                     Span(

@@ -118,10 +118,7 @@ class SQLiteTraceStore(TraceStore):
         # Build the attributes selection
        attributes_select = "s.attributes"
         if attributes_to_return:
-            json_object = ", ".join(
-                f"'{key}', json_extract(s.attributes, '$.{key}')"
-                for key in attributes_to_return
-            )
+            json_object = ", ".join(f"'{key}', json_extract(s.attributes, '$.{key}')" for key in attributes_to_return)
             attributes_select = f"json_object({json_object})"
 
         # SQLite CTE query with filtered attributes

@@ -45,16 +45,12 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
         def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
             class_name = self.__class__.__name__
             method_name = method.__name__
-            span_type = (
-                "async_generator" if is_async_gen else "async" if is_async else "sync"
-            )
+            span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
             sig = inspect.signature(method)
             param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
             combined_args = {}
             for i, arg in enumerate(args):
-                param_name = (
-                    param_names[i] if i < len(param_names) else f"position_{i + 1}"
-                )
+                param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
                 combined_args[param_name] = serialize_value(arg)
             for k, v in kwargs.items():
                 combined_args[str(k)] = serialize_value(v)

@@ -70,14 +66,10 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
             return class_name, method_name, span_attributes
 
         @wraps(method)
-        async def async_gen_wrapper(
-            self: Any, *args: Any, **kwargs: Any
-        ) -> AsyncGenerator:
+        async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator:
             from llama_stack.providers.utils.telemetry import tracing
 
-            class_name, method_name, span_attributes = create_span_context(
-                self, *args, **kwargs
-            )
+            class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
 
             with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                 try:

@@ -92,9 +84,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
         async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
             from llama_stack.providers.utils.telemetry import tracing
 
-            class_name, method_name, span_attributes = create_span_context(
-                self, *args, **kwargs
-            )
+            class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
 
             with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                 try:

@@ -109,9 +99,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
         def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
             from llama_stack.providers.utils.telemetry import tracing
 
-            class_name, method_name, span_attributes = create_span_context(
-                self, *args, **kwargs
-            )
+            class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
 
             with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
                 try: