From 6e52b4fb4200b25af5b07a8529bde0347927b941 Mon Sep 17 00:00:00 2001 From: sakher Date: Mon, 24 Jun 2024 19:58:16 +0100 Subject: [PATCH] Fixed Anthropic ContentBlock - Replaced with TextBlock --- .../providers/anthropic/anthropic_solver.py | 4 +- .../anthropic/anthropic_solver_test.py | 37 ++++++------------- 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/evals/solvers/providers/anthropic/anthropic_solver.py b/evals/solvers/providers/anthropic/anthropic_solver.py index bb7fe50e24..e0cdbebc74 100644 --- a/evals/solvers/providers/anthropic/anthropic_solver.py +++ b/evals/solvers/providers/anthropic/anthropic_solver.py @@ -2,7 +2,7 @@ import anthropic from anthropic import Anthropic -from anthropic.types import ContentBlock, MessageParam, Usage +from anthropic.types import MessageParam, TextBlock, Usage from evals.record import record_sampling from evals.solvers.solver import Solver, SolverResult @@ -99,7 +99,7 @@ def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam] anth_msgs = [ MessageParam( role=oai_to_anthropic_role[msg.role], - content=[ContentBlock(text=msg.content, type="text")], + content=[TextBlock(text=msg.content, type="text")], ) for msg in msgs ] diff --git a/evals/solvers/providers/anthropic/anthropic_solver_test.py b/evals/solvers/providers/anthropic/anthropic_solver_test.py index 9ba8fb1470..864c67d073 100644 --- a/evals/solvers/providers/anthropic/anthropic_solver_test.py +++ b/evals/solvers/providers/anthropic/anthropic_solver_test.py @@ -1,14 +1,11 @@ import os + import pytest +from anthropic.types import MessageParam, TextBlock, Usage from evals.record import DummyRecorder +from evals.solvers.providers.anthropic.anthropic_solver import AnthropicSolver, anth_to_openai_usage from evals.task_state import Message, TaskState -from evals.solvers.providers.anthropic.anthropic_solver import ( - AnthropicSolver, - anth_to_openai_usage, -) - -from anthropic.types import ContentBlock, MessageParam, Usage IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" MODEL_NAME = "claude-instant-1.2" @@ -32,9 +29,7 @@ def dummy_recorder(): yield recorder -@pytest.mark.skipif( - IN_GITHUB_ACTIONS, reason="API tests are wasteful to run on every commit." -) +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="API tests are wasteful to run on every commit.") def test_solver(dummy_recorder, anthropic_solver): """ Test that the solver generates a response coherent with the message history @@ -55,9 +50,7 @@ def test_solver(dummy_recorder, anthropic_solver): ) solver_res = solver(task_state=task_state) - assert ( - solver_res.output == answer - ), f"Expected '{answer}', but got {solver_res.output}" + assert solver_res.output == answer, f"Expected '{answer}', but got {solver_res.output}" def test_message_format(): @@ -71,9 +64,7 @@ def test_message_format(): msgs = [ Message(role="user", content="What is 2 + 2?"), Message(role="system", content="reason step by step"), - Message( - role="assistant", content="I don't need to reason for this, 2+2 is just 4" - ), + Message(role="assistant", content="I don't need to reason for this, 2+2 is just 4"), Message(role="system", content="now, given your reasoning, provide the answer"), ] anth_msgs = AnthropicSolver._convert_msgs_to_anthropic_format(msgs) @@ -82,24 +73,20 @@ def test_message_format(): MessageParam( role="user", content=[ - ContentBlock(text="What is 2 + 2?", type="text"), - ContentBlock(text="reason step by step", type="text"), + TextBlock(text="What is 2 + 2?", type="text"), + TextBlock(text="reason step by step", type="text"), ], ), MessageParam( role="assistant", content=[ - ContentBlock( - text="I don't need to reason for this, 2+2 is just 4", type="text" - ), + TextBlock(text="I don't need to reason for this, 2+2 is just 4", type="text"), ], ), MessageParam( role="user", content=[ - ContentBlock( - text="now, given your reasoning, provide the answer", type="text" - ), + TextBlock(text="now, given your reasoning, provide the answer", type="text"), ], ), ] @@ -126,6 +113,4 @@ def test_anth_to_openai_usage_zero_tokens(): "prompt_tokens": 0, "total_tokens": 0, } - assert ( - anth_to_openai_usage(usage) == expected - ), "Zero token cases are not handled correctly." + assert anth_to_openai_usage(usage) == expected, "Zero token cases are not handled correctly."