Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add codellama to tokenizer list for set_pad_token #3598

Merged
merged 2 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion ludwig/utils/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from bitsandbytes.nn.modules import Embedding
from transformers import (
AutoModelForCausalLM,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
GPT2Tokenizer,
GPT2TokenizerFast,
LlamaTokenizer,
Expand Down Expand Up @@ -40,7 +42,17 @@ def set_pad_token(tokenizer: PreTrainedTokenizer):
# These recommend using eos tokens instead
# https://github.com/huggingface/transformers/issues/2648#issuecomment-616177044
# https://github.com/huggingface/transformers/issues/2630#issuecomment-1290809338
if any(isinstance(tokenizer, t) for t in [GPT2Tokenizer, GPT2TokenizerFast, LlamaTokenizer, LlamaTokenizerFast]):
if any(
isinstance(tokenizer, t)
for t in [
GPT2Tokenizer,
GPT2TokenizerFast,
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
]
):
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

Expand Down
18 changes: 16 additions & 2 deletions ludwig/utils/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,14 @@ def get_unk_token(self) -> str:
def _set_pad_token(self) -> None:
"""Sets the pad token and pad token ID for the tokenizer."""

from transformers import GPT2Tokenizer, GPT2TokenizerFast, LlamaTokenizer, LlamaTokenizerFast
from transformers import (
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
GPT2Tokenizer,
GPT2TokenizerFast,
LlamaTokenizer,
LlamaTokenizerFast,
)

# Tokenizers might have the pad token id attribute since they tend to use the same base class, but
# it can be set to None so we check for this explicitly.
Expand All @@ -822,7 +829,14 @@ def _set_pad_token(self) -> None:
# https://github.com/huggingface/transformers/issues/2648#issuecomment-616177044
if any(
isinstance(self.tokenizer, t)
for t in [GPT2Tokenizer, GPT2TokenizerFast, LlamaTokenizer, LlamaTokenizerFast]
for t in [
GPT2Tokenizer,
GPT2TokenizerFast,
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
]
):
if hasattr(self.tokenizer, "eos_token") and self.tokenizer.eos_token is not None:
logger.warning("No padding token id found. Using eos_token as pad_token.")
Expand Down