Skip to content

Commit

Permalink
Make kv_cache Optional (#1207)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbontrager authored Jul 23, 2024
1 parent a8cee18 commit 7f8bf88
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
3 changes: 2 additions & 1 deletion recipes/configs/generation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ tokenizer:
path: /tmp/Llama-2-7b-hf/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Hello, my name is"
prompt: "Tell me a joke?"
instruct_template: null
chat_format: null
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
enable_kv_cache: True

quantizer: null
7 changes: 5 additions & 2 deletions recipes/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
model_cfg=cfg.model,
model_state_dict=ckpt_dict[utils.MODEL_KEY],
enable_kv_cache=cfg.enable_kv_cache,
)
self._tokenizer = config.instantiate(cfg.tokenizer)

def _setup_model(
self,
model_cfg: DictConfig,
model_state_dict: Dict[str, Any],
enable_kv_cache: bool = True,
) -> nn.Module:
with utils.set_default_dtype(self._dtype), self._device:
model = config.instantiate(model_cfg)
Expand All @@ -77,8 +79,9 @@ def _setup_model(
logger.info(f"Model is initialized with precision {self._dtype}.")

# Ensure the cache is set up on the right device
with self._device:
model.setup_caches(batch_size=1, dtype=self._dtype)
if enable_kv_cache:
with self._device:
model.setup_caches(batch_size=1, dtype=self._dtype)

return model

Expand Down

0 comments on commit 7f8bf88

Please sign in to comment.