Skip to content

Commit

Permalink
fixing tests
Browse files — browse the repository at this point in the history
  • Loading branch information
SalmanMohammadi committed Sep 16, 2024
1 parent 726abb0 commit dacd80e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
3 changes: 2 additions & 1 deletion recipes/eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def _model_generate(
# Technically this is not necessary, but it's a good way to ensure that
# the caches won't error on a different batch size. In addition, caches
# are not needed for a regular model call, so we just setup here
# TODO @joecummings this is being called multiple times resulting in many WARNINGs
if self.enable_kv_cache:
with context.device:
self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype)
Expand All @@ -154,7 +155,7 @@ def _model_generate(
"``do_sample`` for generation tasks is not supported yet in torchtune."
)

toks = generation.generate(
toks, _ = generation.generate(
self._model,
context,
max_generated_tokens=self.max_gen_toks,
Expand Down
13 changes: 11 additions & 2 deletions tests/recipes/test_eleuther_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@


class TestEleutherEval:
@pytest.mark.parametrize(
"eval_name, expected_acc, bsz",
[("truthfulqa_gen", 0.1, 1), ("truthfulqa_mc2", 0.3, 8)],
)
@pytest.mark.integration_test
def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir):
def test_torchtune_checkpoint_eval_results(
self, capsys, monkeypatch, tmpdir, eval_name, expected_acc, bsz
):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

# TODO @joecummings bsz > 1 isn't supported for generation tasks, update test once integrated
cmd = f"""
tune run eleuther_eval \
--config eleuther_evaluation \
Expand All @@ -39,6 +46,8 @@ def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir):
limit=10 \
dtype=fp32 \
device=cpu \
tasks=[{eval_name}]\
batch_size={bsz} \
""".split()

model_config = llama2_test_config()
Expand Down Expand Up @@ -66,7 +75,7 @@ def test_torchtune_checkpoint_eval_results(self, capsys, monkeypatch, tmpdir):
)
assert search_results is not None
acc_result = float(search_results.group(1))
assert math.isclose(acc_result, 0.3, abs_tol=0.05)
assert math.isclose(acc_result, expected_acc, abs_tol=0.05)

@pytest.fixture
def hide_available_pkg(self, monkeypatch):
Expand Down

0 comments on commit dacd80e

Please sign in to comment.