Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-16 19:59:26 +00:00)
rebase and fix some small breakage due to model -> model_id fix
parent 22aedd0277
commit 1cb42d3060
6 changed files with 20 additions and 11 deletions
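The rename this commit adapts to is mechanical but cross-cutting: inference entry points now take the model by the keyword model_id instead of model, and providers resolve the stored model themselves. A minimal sketch of the post-rename calling convention, assuming placeholder objects shaped like the ones in the test hunks below (not a definitive API reference):

    # Minimal sketch (placeholder objects); the only change relevant to this
    # commit is the keyword rename from `model=` to `model_id=`.
    async def run_completion(inference_impl, inference_model, user_input):
        return await inference_impl.completion(
            model_id=inference_model,  # was: model=inference_model before the rename
            content=user_input,
            stream=False,
        )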
@@ -38,15 +38,15 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
     if api == Api.inference:
         return await p.register_model(obj)
     elif api == Api.safety:
-        await p.register_shield(obj)
+        return await p.register_shield(obj)
     elif api == Api.memory:
-        await p.register_memory_bank(obj)
+        return await p.register_memory_bank(obj)
     elif api == Api.datasetio:
-        await p.register_dataset(obj)
+        return await p.register_dataset(obj)
     elif api == Api.scoring:
-        await p.register_scoring_function(obj)
+        return await p.register_scoring_function(obj)
     elif api == Api.eval:
-        await p.register_eval_task(obj)
+        return await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
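The added return statements matter for callers of register_object_with_provider: whatever the provider's register_* coroutine produces now propagates back instead of being silently dropped. An illustrative caller (the helper name and fallback are hypothetical, not from this diff):

    # Hypothetical caller: capture the provider's result, falling back to the
    # original object if the provider's register_* returns nothing.
    async def register_and_get(obj, provider):
        registered = await register_object_with_provider(obj, provider)
        return registered if registered is not None else obj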
@@ -234,7 +234,7 @@ class LlamaGuardShield:
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
         async for chunk in await self.inference_api.chat_completion(
-            model=self.model,
+            model_id=self.model,
             messages=[shield_input_message],
             stream=True,
         ):
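The hunk above cuts off at the loop header; the surrounding shield code accumulates the streamed completion into content. A hedged sketch of that pattern, with the chunk/event structure treated as an assumption rather than taken verbatim from this diff:

    # Hedged sketch of consuming the stream above; chunk.event.delta is an
    # assumed shape, not copied from this commit.
    async def collect_shield_response(inference_api, model_id, shield_input_message):
        content = ""
        async for chunk in await inference_api.chat_completion(
            model_id=model_id,  # the keyword this commit renames
            messages=[shield_input_message],
            stream=True,
        ):
            delta = getattr(chunk.event, "delta", None)
            if isinstance(delta, str):
                content += delta
        return content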
@@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
-        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "agents": "meta_reference",
+        },
+        id="fireworks",
+        marks=pytest.mark.fireworks,
+    ),
     pytest.param(
         {
             "inference": "remote",
@@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [


 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote"]:
+    for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",
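With fireworks added to the registered marks, the new provider combination can be filtered the usual pytest way, e.g. pytest -m fireworks. A minimal sketch of a test carrying the marker (the test name and body are illustrative):

    import pytest

    # Illustrative test: the `fireworks` marker is registered by pytest_configure
    # above, so `pytest -m fireworks` selects only tests that carry it.
    @pytest.mark.fireworks
    def test_fireworks_mark_is_selectable():
        assert True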
@@ -147,9 +147,9 @@ class TestInference:

         user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
         response = await inference_impl.completion(
+            model_id=inference_model,
             content=user_input,
             stream=False,
-            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -55,7 +55,7 @@ class TestVisionModelInference:
         )

         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
             messages=[
                 UserMessage(content="You are a helpful assistant."),
                 UserMessage(content=[image, "Describe this image in two sentences."]),
@@ -102,7 +102,7 @@ class TestVisionModelInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
                     UserMessage(