
Tokenizer redesign for better model-specific feature support #1082

Merged
RdoubleA merged 18 commits into pytorch:main from the tokenizer branch on Jul 2, 2024

Conversation

@RdoubleA (Contributor) commented Jun 12, 2024

Motivation

  • Our tokenizers are not model-specific; they are tied to the base tokenizer, e.g., SentencePiece (SP) or TikToken (TT). Both implement tokenize_messages, but it is not configurable for model-specific special tokens.
  • Both are highly coupled to Llama2 and Llama3. For example, all the special tokens for TikToken are Llama3-specific and are difficult to customize.
  • This has led to Phi3 having its own tokenizer file that reimplements SentencePieceTokenizer with custom special tokens. Ideally, this should be easily composable: configure the Phi3-specific tokens on top of SentencePiece as the base.
  • This problem will only get worse as foundation models add more capabilities, such as tool calling, to their tokenizers. For example, Mistral v3 now uses its own special tokens on top of SentencePiece, and we have no way to onboard this besides copying and reimplementing SentencePieceTokenizer. Redundancy is normally ok, but not when it can be easily addressed with composability.
  • If a new base tokenizer emerges besides SentencePiece or TikToken, we should be flexible enough to quickly implement the new encode and decode while still maintaining our higher-level APIs (tokenize_messages, etc.).

Design proposal

Intuition: separate the two core APIs (encode/decode and tokenize_messages) that operate at different levels of abstraction:

  • encode/decode live in the base token encoding layer (SP or TT) and do not need model-specific logic. They should be pulled out so any model can quickly use SP, TT, or any other base token encoding for encode/decode.
  • tokenize_messages requires model-specific special tokens placed in particular locations. Thus, there should be a model-specific class that implements tokenize_messages.

We can achieve the above with BaseTokenizer and ModelTokenizer.

  • BaseTokenizer is the abstract interface for any base tokenization model (SP or TT) that implements encode and decode.
  • ModelTokenizer is the abstract interface for any model-specific tokenizer that implements tokenize_messages. All models will implement their own Tokenizer class based on this interface so they can control the tokenize_messages logic.

from typing import List, Protocol, Tuple

from torchtune.data import Message


class BaseTokenizer(Protocol):
    """Abstract token encoding model"""

    def encode(self, text: str, **kwargs) -> List[int]:
        """
        Given a string, return a list of token ids.
        """

    def decode(
        self, token_ids: List[int], include_special: bool = False, **kwargs
    ) -> str:
        """
        Given a list of token ids, return the decoded text.
        """

class ModelTokenizer(Protocol):
    """
    Abstract tokenizer that implements model-specific special token logic in
    the ``tokenize_messages`` method.
    """

    def tokenize_messages(
        self, messages: List[Message], **kwargs
    ) -> Tuple[List[int], List[bool]]:
        """
        Given a list of messages, return a list of tokens and list of masks for
        the concatenated and formatted messages.
        """
        pass
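
Since these are typing.Protocol classes, conformance is structural: any class with a matching tokenize_messages satisfies ModelTokenizer without inheriting from it. A minimal sketch of downstream code typed against the protocol (build_sample is a hypothetical name, not part of this PR):

def build_sample(
    messages: List[Message], tokenizer: ModelTokenizer
) -> Tuple[List[int], List[bool]]:
    # Any class implementing tokenize_messages satisfies ModelTokenizer
    # structurally; no inheritance is required.
    return tokenizer.tokenize_messages(messages)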

This means SentencePieceTokenizer and TikTokenTokenizer will be refactored to separate the encode/decode logic from the tokenize_messages logic.

from typing import Dict, Optional, Set

from sentencepiece import SentencePieceProcessor
from tiktoken import Encoding
from tiktoken.load import load_tiktoken_bpe


class TikTokenBaseTokenizer(BaseTokenizer):
    def __init__(
        self,
        path: str,
        pattern: str,
        name: str,
        special_tokens: Dict[str, int],
    ):
        mergeable_ranks = load_tiktoken_bpe(path)
        self.tt_model = Encoding(
            name=name,
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens,
        )
        ...

    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
        allowed_special: Set[str] = set(),
    ) -> List[int]:
        ...

    def decode(
        self,
        token_ids: List[int],
        truncate_at_eos: bool = True,
    ) -> str:
        ...

class SentencePieceBaseTokenizer(BaseTokenizer):
    def __init__(
        self,
        path: str,
    ):
        spm_model = SentencePieceProcessor()
        spm_model.load(path)
        self.spm_model = spm_model
        # Special tokens are defined in the protobuf itself

    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
        allowed_special: Set[str] = set(),
        trim_leading_whitespace: bool = False,
        prefix: Optional[str] = None,
    ) -> List[int]:
        ...

    def decode(
        self,
        token_ids: List[int],
    ) -> str:
        ...
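
Either base tokenizer can then be dropped in anywhere a BaseTokenizer is expected; a quick sketch (the path is a placeholder):

# Both backends expose the same encode/decode surface, so callers don't
# need to know which one is in use.
sp_tokenizer = SentencePieceBaseTokenizer(path="/tmp/tokenizer.model")
token_ids = sp_tokenizer.encode("hello world", add_bos=True, add_eos=True)
text = sp_tokenizer.decode(token_ids)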

And any model tokenizers would compose with the above classes:

class Llama3Tokenizer(ModelTokenizer):
    """
    tiktoken tokenizer configured with Llama3 Instruct's special tokens, as described in
    https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3

    Args:
        path (str): Path to pretrained tiktoken tokenizer file.
    """

    def __init__(
        self,
        path: str,
    ):
        all_special_tokens_with_ids = self._get_all_special_tokens_with_ids()
        # CL100K_PATTERN is the regex split pattern used for BPE pre-tokenization
        self.tt_model = TikTokenBaseTokenizer(
            path=path,
            name="llama3_tiktoken",
            pattern=CL100K_PATTERN,
            special_tokens=all_special_tokens_with_ids,
        )

    def tokenize_messages(
        self,
        messages: List[Message],
        max_seq_len: Optional[int] = None,
        tokenize_header: bool = True,
        add_eos: bool = True,
    ) -> Tuple[List[int], List[bool]]:
        ...
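
Usage of the composed tokenizer would then look roughly like this (the path and message contents are illustrative):

tokenizer = Llama3Tokenizer(path="/tmp/llama3/tokenizer.model")

messages = [
    Message(role="user", content="What is the capital of France?"),
    Message(role="assistant", content="The capital of France is Paris."),
]

# Model-specific API: Llama3 special tokens and headers are handled internally.
tokens, mask = tokenizer.tokenize_messages(messages)

# Base-level API: plain token encoding with no model-specific logic.
token_ids = tokenizer.tt_model.encode("hello world", add_bos=True, add_eos=True)
text = tokenizer.tt_model.decode(token_ids)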

Changelog

  • Refactor SentencePieceTokenizer -> SentencePieceBaseTokenizer, TikTokenTokenizer -> TikTokenBaseTokenizer and only retain encode/decode logic
  • Move tokenize_messages logic to Llama2Tokenizer and Llama3Tokenizer for SP and TT, respectively
  • Create GemmaTokenizer and MistralTokenizer, which leverage SentencePieceBaseTokenizer. The tokenize_messages logic is identical to Llama2Tokenizer's (for now; Mistral needs to be updated for v3)
  • Refactor Phi3's tokenizer to use SentencePieceBaseTokenizer while retaining its special token logic (see the sketch after this list)
  • Revamp DummyTokenizer, since it previously inherited from SentencePieceTokenizer
  • Update all tests and docstrings
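
To make the composability concrete, here is a minimal sketch of how the refactored Phi3 tokenizer can wrap SentencePieceBaseTokenizer rather than reimplement it (the class name and special token ids are illustrative, not the exact implementation):

class Phi3MiniTokenizer(ModelTokenizer):
    def __init__(self, path: str):
        # Delegate raw encode/decode to the shared SentencePiece base.
        self._base = SentencePieceBaseTokenizer(path=path)
        # Phi3-specific special tokens layered on top (ids illustrative).
        self.special_tokens = {"<|endoftext|>": 32000, "<|assistant|>": 32001}

    def tokenize_messages(
        self, messages: List[Message], max_seq_len: Optional[int] = None
    ) -> Tuple[List[int], List[bool]]:
        # Encode each message with the base tokenizer and splice the Phi3
        # special tokens around it (implementation details elided).
        ...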

Test plan

Planned follow-ups

  • Tokenizer live docs / API ref updates

@pytorch-bot commented Jun 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1082


✅ No Failures

As of commit 93028cf with merge base 95ccf40:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 12, 2024
@kartikayk (Contributor) left a comment

Want to express my uncertainty on the TokenEncoding naming. This seems a bit unintuitive since I'd imagine the encoding to refer to something like UTF8.

Not the best, but should we consider Tokenizer (so SentencePieceTokenizer derives from Tokenizer) and ModelTokenizer (so Llama3Tokenizer derives from ModelTokenizer)? I don't love ModelTokenizer, but with the right docstrings I think it's passable.

@ebsmothers (Contributor) commented

Overall I like the proposal. Still going through all the code, but the division of encode and tokenize_messages APIs you've provided here definitely makes sense to me. One big question I have is whether we can do this in a way where we don't have to write an entirely new tokenizer class every time we onboard a new model. I get that if there's custom logic around special tokens we can't really avoid it, but for vanilla SentencePiece without much special formatting it'd be nice to make this process a bit more lightweight (e.g., the Gemma and Llama2 tokenizers have the same tokenize_messages API, and imo copy-pasting this code is not as nice as copy-pasting model builders, which are generally pretty straightforward to understand).
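
One way to make that lightweight would be a shared utility that model tokenizers without custom special-token logic can delegate to; a rough sketch (helper name hypothetical, truncation simplified):

def tokenize_messages_no_special_tokens(
    tokenizer: BaseTokenizer,
    messages: List[Message],
    bos_id: int,
    eos_id: int,
    max_seq_len: Optional[int] = None,
) -> Tuple[List[int], List[bool]]:
    # Concatenate messages with only BOS/EOS added, tracking a loss mask
    # per token via each message's masked flag.
    tokens, mask = [bos_id], [True]
    for message in messages:
        encoded = tokenizer.encode(message.content, add_bos=False, add_eos=False)
        tokens += encoded
        mask += [message.masked] * len(encoded)
    tokens.append(eos_id)
    mask.append(mask[-1])
    if max_seq_len is not None:
        tokens, mask = tokens[:max_seq_len], mask[:max_seq_len]
    return tokens, mask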

Also:

Move torchtune.modules.tokenizers to torchtune.data.tokenizers which is a more relevant location

I understand that this is a useful thing to do, but it makes reviewing this code harder and is kinda logically distinct from the important stuff in this PR (imo). With this move, I have to look at tokenizers in their entirety; without it I can actually see the diff.

@RdoubleA (Contributor, Author) commented

Want to express my uncertainty on the TokenEncoding naming. This seems a bit unintuitive since I'd imagine the encoding to refer to something like UTF8.

Yeah, I agree I don't like the TokenEncoding naming either, but naming SP and TT as Tokenizer and everything else as ModelTokenizer is more confusing imo... there should be a clear distinction between a model tokenizer and the base tokenizer. Maybe we could call them SentencePieceBaseTokenizer and TikTokenBaseTokenizer.

@RdoubleA (Contributor, Author) commented

@joecummings huh I guess Gemma does have special tokens, but our current tokenizer does not use them. what do you think about either punting the upgrade to later or parsing the special tokens from the HF json for now and then adding them in tokenize_messages as appropriate later? Same question for Mistral

@RdoubleA changed the title from "[RFC] Tokenizer redesign for better model-specific feature support" to "Tokenizer redesign for better model-specific feature support" on Jun 25, 2024
@codecov-commenter commented

Codecov Report

Attention: Patch coverage is 93.66516% with 28 lines in your changes missing coverage. Please review.

Project coverage is 67.45%. Comparing base (f200da5) to head (1d6e5e3).
Report is 1 commit behind head on main.

Files                                        Patch %    Lines
torchtune/modules/tokenizers/_utils.py       80.48%     8 Missing ⚠️
torchtune/models/llama3/_tokenizer.py        88.33%     7 Missing ⚠️
torchtune/models/gemma/_tokenizer.py         89.28%     3 Missing ⚠️
torchtune/models/llama2/_tokenizer.py        89.28%     3 Missing ⚠️
torchtune/models/mistral/_tokenizer.py       89.28%     3 Missing ⚠️
torchtune/models/phi3/_tokenizer.py          84.61%     2 Missing ⚠️
recipes/eleuther_eval.py                      0.00%     1 Missing ⚠️
torchtune/modules/tokenizers/_tiktoken.py    96.55%     1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1082       +/-   ##
===========================================
+ Coverage   26.74%   67.45%   +40.71%     
===========================================
  Files         183      191        +8     
  Lines        8362     8498      +136     
===========================================
+ Hits         2236     5732     +3496     
+ Misses       6126     2766     -3360     

☔ View full report in Codecov by Sentry.

@joecummings (Contributor) commented

@joecummings huh I guess Gemma does have special tokens, but our current tokenizer does not use them. what do you think about either punting the upgrade to later or parsing the special tokens from the HF json for now and then adding them in tokenize_messages as appropriate later? Same question for Mistral

I opened an issue to address this later: #1118.

@ebsmothers (Contributor) left a comment

A handful more comments but no huge concerns from my side. Accepting to unblock.

@joecummings (Contributor) left a comment

This is amazing work, but I see a few things that definitely need to be documented before landing.

@RdoubleA merged commit f158577 into pytorch:main on Jul 2, 2024
29 checks passed
@RdoubleA deleted the tokenizer branch on Jul 2, 2024
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request on Jul 13, 2024
Labels: CLA Signed

6 participants