Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 00:05:18 +00:00)

Commit 95b7f57d92 (parent e4f14eafe2): use provider resource id to validate for models

7 changed files with 75 additions and 46 deletions

@@ -226,7 +226,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -237,7 +237,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
@@ -254,6 +254,6 @@ class Inference(Protocol):
     @webmethod(route="/inference/embeddings")
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...

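In short, every Inference API method now takes a model_id keyword instead of model. Below is a minimal sketch of the new call shape using a stand-in class rather than a real Llama Stack implementation (EchoInference, the message dicts, and the model name are illustrative assumptions):

# Sketch only: EchoInference stands in for an Inference implementation.
import asyncio
from typing import List, Optional


class EchoInference:
    async def chat_completion(
        self,
        model_id: str,  # renamed from `model` in this commit
        messages: List[dict],
        stream: Optional[bool] = False,
    ) -> dict:
        # A real provider would run the model named by `model_id`.
        return {"model_id": model_id, "completion": f"echo: {messages[-1]['content']}"}


async def main() -> None:
    impl = EchoInference()
    print(
        await impl.chat_completion(
            model_id="Llama3.1-8B-Instruct",  # example identifier
            messages=[{"role": "user", "content": "hello"}],
            stream=False,
        )
    )


asyncio.run(main())
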
@@ -95,7 +95,7 @@ class InferenceRouter(Inference):
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -105,8 +105,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.routing_table.get_model(model_id)
         params = dict(
-            model=model,
+            model_id=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -116,7 +117,7 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        provider = self.routing_table.get_provider_impl(model)
+        provider = self.routing_table.get_provider_impl(model_id)
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
@@ -124,16 +125,17 @@ class InferenceRouter(Inference):
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        provider = self.routing_table.get_provider_impl(model)
+        model = await self.routing_table.get_model(model_id)
+        provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
-            model=model,
+            model_id=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -147,11 +149,12 @@ class InferenceRouter(Inference):
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        return await self.routing_table.get_provider_impl(model).embeddings(
-            model=model,
+        model = await self.routing_table.get_model(model_id)
+        return await self.routing_table.get_provider_impl(model_id).embeddings(
+            model_id=model.provider_resource_id,
             contents=contents,
         )
 

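The router hunks above all follow the same resolve-then-dispatch pattern: look up the registered model by model_id, pick the provider for that model_id, and forward the model's provider_resource_id to the provider. Below is a simplified sketch of that flow; RoutingTable, Model, and FakeProvider here are stand-ins rather than the real llama-stack types, and the resource id string is made up:

import asyncio
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # stack-level id callers use
    provider_resource_id: str  # id the backing provider understands


class RoutingTable:
    def __init__(self, models, provider):
        self._models = {m.identifier: m for m in models}
        self._provider = provider

    async def get_model(self, model_id: str) -> Model:
        return self._models[model_id]

    def get_provider_impl(self, model_id: str):
        return self._provider


class FakeProvider:
    async def completion(self, model_id: str, content: str) -> str:
        # Receives the provider resource id, not the stack identifier.
        return f"[{model_id}] {content} ..."


async def main() -> None:
    table = RoutingTable(
        [Model("llama_8b", "provider/llama-3.1-8b")],  # made-up resource id
        FakeProvider(),
    )
    # Same flow as InferenceRouter.completion: resolve, then forward.
    model = await table.get_model("llama_8b")
    provider = table.get_provider_impl("llama_8b")
    print(await provider.completion(model_id=model.provider_resource_id, content="Roses are red,"))


asyncio.run(main())
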
@@ -74,7 +74,7 @@ class FireworksInferenceAdapter(
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -82,7 +82,7 @@ class FireworksInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -138,7 +138,7 @@ class FireworksInferenceAdapter(
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -149,7 +149,7 @@ class FireworksInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -229,7 +229,7 @@ class FireworksInferenceAdapter(
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -63,7 +63,7 @@ class TogetherInferenceAdapter(
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -71,7 +71,7 @@ class TogetherInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -135,7 +135,7 @@ class TogetherInferenceAdapter(
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -146,7 +146,7 @@ class TogetherInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -221,7 +221,7 @@ class TogetherInferenceAdapter(
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
     )
 
 
+def get_model_short_name(model_name: str) -> str:
+    """Convert model name to a short test identifier.
+
+    Args:
+        model_name: Full model name like "Llama3.1-8B-Instruct"
+
+    Returns:
+        Short name like "llama_8b" suitable for test markers
+    """
+    model_name = model_name.lower()
+    if "vision" in model_name:
+        return "llama_vision"
+    elif "3b" in model_name:
+        return "llama_3b"
+    elif "8b" in model_name:
+        return "llama_8b"
+    else:
+        return model_name.replace(".", "_").replace("-", "_")
+
+
+@pytest.fixture(scope="session")
+def model_id(inference_model) -> str:
+    return get_model_short_name(inference_model)
+
+
 INFERENCE_FIXTURES = [
     "meta_reference",
     "ollama",
@@ -154,7 +179,7 @@ INFERENCE_FIXTURES = [
 
 
 @pytest_asyncio.fixture(scope="session")
-async def inference_stack(request, inference_model):
+async def inference_stack(request, inference_model, model_id):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
@@ -163,7 +188,7 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=inference_model,
+                model_id=model_id,
             )
         ],
     )

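As a quick check of the new fixture helper, this is the mapping get_model_short_name produces for a few example inputs; the function body is copied from the hunk above so the snippet runs on its own, and the input names are only examples:

def get_model_short_name(model_name: str) -> str:
    # Copied from the fixtures hunk above.
    model_name = model_name.lower()
    if "vision" in model_name:
        return "llama_vision"
    elif "3b" in model_name:
        return "llama_3b"
    elif "8b" in model_name:
        return "llama_8b"
    else:
        return model_name.replace(".", "_").replace("-", "_")


assert get_model_short_name("Llama3.1-8B-Instruct") == "llama_8b"
assert get_model_short_name("Llama3.2-3B-Instruct") == "llama_3b"
assert get_model_short_name("Llama3.2-11B-Vision-Instruct") == "llama_vision"
assert get_model_short_name("Some-Other.Model") == "some_other_model"
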
@@ -64,7 +64,7 @@ def sample_tool_definition():
 
 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack):
+    async def test_model_list(self, inference_model, inference_stack, model_id):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,17 +73,16 @@ class TestInference:
 
         model_def = None
         for model in response:
-            if model.identifier == inference_model:
+            if model.identifier == model_id:
                 model_def = model
                 break
 
         assert model_def is not None
 
     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack):
+    async def test_completion(self, inference_model, inference_stack, model_id):
         inference_impl, _ = inference_stack
-
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -96,7 +95,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -110,7 +109,7 @@ class TestInference:
            async for r in await inference_impl.completion(
                content="Roses are red,",
                stream=True,
-                model=inference_model,
+                model_id=model_id,
                sampling_params=SamplingParams(
                    max_tokens=50,
                ),
@@ -125,11 +124,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack
+        self, inference_model, inference_stack, model_id
     ):
         inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -149,7 +148,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -167,11 +166,11 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages, model_id
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -184,11 +183,11 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params
+        self, inference_model, inference_stack, common_params, model_id
     ):
         inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -204,7 +203,7 @@ class TestInference:
             num_seasons_in_nba: int
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -227,7 +226,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -244,13 +243,13 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages, model_id
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=model_id,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -277,6 +276,7 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
+        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,6 +316,7 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
+        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -327,7 +328,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=model_id,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,

@@ -29,7 +29,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return self.stack_to_provider_models_map[identifier]
 
     async def register_model(self, model: Model) -> None:
-        if model.identifier not in self.stack_to_provider_models_map:
+        if model.provider_resource_id not in self.stack_to_provider_models_map:
             raise ValueError(
-                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
+                f"Unsupported model {model.provider_resource_id}. Supported models: {self.stack_to_provider_models_map.keys()}"
             )
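The registry hunk above switches validation to the provider-facing id: registration is rejected unless model.provider_resource_id appears in the provider's supported-model map, regardless of the stack-level identifier. A toy sketch of that behavior follows; the class, Model dataclass, and map contents are simplified stand-ins, not the real ModelRegistryHelper:

import asyncio
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str
    provider_resource_id: str


class ToyRegistryHelper:
    def __init__(self, stack_to_provider_models_map: dict):
        self.stack_to_provider_models_map = stack_to_provider_models_map

    async def register_model(self, model: Model) -> None:
        # Validate against the provider-facing id, not the stack identifier.
        if model.provider_resource_id not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Unsupported model {model.provider_resource_id}. "
                f"Supported models: {list(self.stack_to_provider_models_map)}"
            )


helper = ToyRegistryHelper({"Llama3.1-8B-Instruct": "provider-model-8b"})
# Accepted: the provider_resource_id is in the map even though the identifier is an alias.
asyncio.run(helper.register_model(Model("llama_8b", "Llama3.1-8B-Instruct")))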