HunyuanVideo w. BitsAndBytes (local): Expected all tensors to be on the same device #10500

Open
tin2tin opened this issue Jan 8, 2025 · 8 comments
Labels: bug (Something isn't working)

Comments

tin2tin commented Jan 8, 2025

Describe the bug

Errors in the HunyuanVideo examples on the diffusers hunyuan_video documentation page.

Reproduction

Run this code from the link:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
    "tencent/HunyuanVideo",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipeline = HunyuanVideoPipeline.from_pretrained(
    "tencent/HunyuanVideo",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

prompt = "A cat walks on the grass, realistic style."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=15)

Gives this error:
HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/tencent/HunyuanVideo/resolve/main/transformer/config.json

Changing the path to: hunyuanvideo-community/HunyuanVideo

Gives this error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

And the other example crashes on an RTX 4090 due to OOM.

(I wanted to check whether FastHunyuan-diffusers would be more VRAM-friendly, but couldn't because of these errors.)

Logs

Logs inserted above.

System Info

Win 11

Who can help?

@DN6 @a-r-r-o-w

tin2tin added the bug (Something isn't working) label on Jan 8, 2025
tin2tin changed the title from "Errors in the HunyuanVideo examples" to "Errors in the HunyuanVideo examples/inference code" on Jan 8, 2025
SahilCarterr (Contributor) commented:

Add revision='refs/pr/18' to the from_pretrained calls until the weights are merged.

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
    "tencent/HunyuanVideo",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    revision='refs/pr/18'
)

pipeline = HunyuanVideoPipeline.from_pretrained(
    "tencent/HunyuanVideo",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
    revision='refs/pr/18'
)

prompt = "A cat walks on the grass, realistic style."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=15)

@tin2tin

tin2tin (Author) commented Jan 8, 2025

@SahilCarterr
Thank you. With your code/weights, I still get this error (running it locally):
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

a-r-r-o-w (Member) commented Jan 8, 2025

@tin2tin, this seems to run perfectly fine for me without errors:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipeline = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)
print(pipeline.text_encoder.device)
print(pipeline.transformer.device)
print(pipeline.vae.device)

prompt = "A cat walks on the grass, realistic style."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=15)

The OOM comes from the default height and width, which are set to 720 and 1280. This requires roughly 64 GB with 61 frames. BnB quantization may not help here because the transformer must be in bfloat16 for decent results. To enable more consumer-friendly inference, I recommend trying out:

  • enable_model_cpu_offload
  • torchao quantization on transformer, because it does support bf16
  • or, instead of quantization, we can do layerwise fp8 upcasting: [core] Layerwise Upcasting #10347

Note that you might still OOM despite these changes because of the resolution being used. Even if it doesn't, you need sufficient CPU RAM to hold the models during offloading. If it OOMs on the CPU, you could load the transformer directly in float8_e4m3fn and then enable layerwise upcasting; I believe this is similar to what makes it runnable with low VRAM in UIs.
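
For reference, a minimal sketch of the enable_model_cpu_offload route, adapted from the snippet above. The lowered height/width values are illustrative assumptions, not settings recommended in this thread:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# No device_map="balanced" here: let model offloading handle device placement,
# so each component is moved to the GPU only while it is actually running.
pipeline = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

prompt = "A cat walks on the grass, realistic style."
# Reducing the resolution below the 720x1280 defaults also lowers peak memory;
# the values below are only an example.
video = pipeline(
    prompt=prompt,
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(video, "cat.mp4", fps=15)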

tin2tin changed the title from "Errors in the HunyuanVideo examples/inference code" to "HunyuanVideo: Expected all tensors to be on the same device" on Jan 9, 2025
tin2tin changed the title from "HunyuanVideo: Expected all tensors to be on the same device" to "HunyuanVideo w. BitsAndBytes: Expected all tensors to be on the same device" on Jan 9, 2025
tin2tin changed the title from "HunyuanVideo w. BitsAndBytes: Expected all tensors to be on the same device" to "HunyuanVideo w. BitsAndBytes (local): Expected all tensors to be on the same device" on Jan 9, 2025
tin2tin (Author) commented Jan 9, 2025

It's probably caused by the text encoder being on the CPU while the rest is on CUDA?

text_encoder: cpu
transformer: cuda
vae: cuda

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Fetching 6 files: 100%|██████████████████████████████████████████████████████████████████████████| 6/6 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.80it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.77it/s]
cpu
cuda:0
cuda:0
Error: Python: Traceback (most recent call last):
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video.py", line 589, in __call__
    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
                                                                 ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video.py", line 314, in encode_prompt
    prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video.py", line 241, in _get_llama_prompt_embeds
    prompt_embeds = self.text_encoder(
                    ^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\transformers\models\llama\modeling_llama.py", line 891, in forward
    inputs_embeds = self.embed_tokens(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\nn\modules\sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "C:\Users\peter\Documents\blender-4.4.0\4.4\python\Lib\site-packages\torch\nn\functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

a-r-r-o-w (Member) commented:

cc @sayakpaul for device_map

Oh, my bad, this is because everything was on CUDA for me since I was testing on an A100 😅 The text encoder should have been moved to CUDA automatically when its forward method was called in your case. I'll try on a lower-VRAM GPU and attempt to replicate.

sayakpaul (Member) commented:

Well, the text encoder is being placed on the CPU because, with the "balanced" device_map, there was no GPU memory left for it.

So, you could first compute the text embeddings and completely delete the text encoders to free up space. https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi#reproducing-the-results-from-the-genmo-mochi-repo shows an example of how encode_prompt() can be leveraged in this case.
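
For anyone who wants to try this with HunyuanVideo, here is a rough sketch of that pattern adapted from the Mochi example. The encode_prompt() and pipeline-call keyword names are assumptions inferred from the traceback above, so check the pipeline source before relying on them:

import gc
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
prompt = "A cat walks on the grass, realistic style."

# 1) Load only the text encoders and compute the prompt embeddings on the GPU.
pipeline = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=None, vae=None, torch_dtype=torch.float16
).to("cuda")
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipeline.encode_prompt(prompt=prompt)

# 2) Free the text encoders before loading the large transformer.
del pipeline
gc.collect()
torch.cuda.empty_cache()

# 3) Reload the pipeline without text encoders and denoise from the precomputed embeddings.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipeline = HunyuanVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=None,
    tokenizer=None,
    text_encoder_2=None,
    tokenizer_2=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

video = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(video, "cat.mp4", fps=15)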

tin2tin (Author) commented Jan 9, 2025

(The Mochi code causes an OOM crash on an RTX 4090.)

sayakpaul (Member) commented:

I provided that example as a reference for you to adapt to your use case; it shows how encode_prompt() can be leveraged in this situation.
