From 2644bd0d6cea677f80e44ed4a44bea5e04aabeb3 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 18 Dec 2024 17:57:50 -0800 Subject: [PATCH] [WebGPU] Make dataToGPU upload to GPU if data is on CPU (#8483) --- tfjs-backend-webgpu/src/backend_webgpu.ts | 9 ++++++--- tfjs-backend-webgpu/src/backend_webgpu_test.ts | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index ceae66c513a..ebf517fa550 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -594,8 +594,9 @@ export class WebGPUBackend extends KernelBackend { * @param dataId The source tensor. */ override readToGPU(dataId: DataId): GPUData { - const srcTensorData = this.tensorMap.get(dataId); - const {values, dtype, shape, resource} = srcTensorData; + let srcTensorData = this.tensorMap.get(dataId); + const {values, dtype, shape} = srcTensorData; + let resource = srcTensorData.resource; if (dtype === 'complex64') { throw new Error('Does not support reading buffer for complex64 dtype.'); @@ -603,7 +604,9 @@ export class WebGPUBackend extends KernelBackend { if (resource == null) { if (values != null) { - throw new Error('Data is not on GPU but on CPU.'); + this.uploadToGPU(dataId); + srcTensorData = this.tensorMap.get(dataId); + resource = srcTensorData.resource; } else { throw new Error('There is no data on GPU or CPU.'); } diff --git a/tfjs-backend-webgpu/src/backend_webgpu_test.ts b/tfjs-backend-webgpu/src/backend_webgpu_test.ts index ed8149f409a..5e06905d7c8 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu_test.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu_test.ts @@ -200,6 +200,18 @@ describeWebGPU('backend webgpu', () => { await c3.data(); tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag); }); + + it('dataToGPU uploads to GPU if the tensor is on CPU', async () => { + const webGPUBackend = (tf.backend() as WebGPUBackend); + const data = [1,2,3,4,5]; + const tensor = tf.tensor1d(data); + const res = tensor.dataToGPU(); + expect(res.buffer).toBeDefined(); + const resData = await webGPUBackend.getBufferData(res.buffer); + const values = tf.util.convertBackendValuesAndArrayBuffer( + resData, res.tensorRef.dtype); + expectArraysEqual(values, data); + }); }); describeWebGPU('backendWebGPU', () => {