PyTorch: saat memuat data batch menggunakan Dataloader, cara mentransfer data ke GPU secara otomatis

Jika kita menggunakan kombinasi kelas Dataset dan Dataloader (seperti yang ditunjukkan di bawah), saya harus memuat data secara eksplisit ke GPU menggunakan .to() atau .cuda(). Apakah ada cara untuk menginstruksikan pemuat data untuk melakukannya secara otomatis/implisit?

Kode untuk memahami/mereproduksi skenario:

from torch.utils.data import Dataset, DataLoader
import numpy as np

class DemoData(Dataset):
    def __init__(self, limit):
        super(DemoData, self).__init__()
        self.data = np.arange(limit)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return (self.data[idx], self.data[idx]*100)

demo = DemoData(100)

loader = DataLoader(demo, batch_size=50, shuffle=True)

for i, (i1, i2) in enumerate(loader):
    print('Batch Index: {}'.format(i))
    print('Shape of data item 1: {}; shape of data item 2: {}'.format(i1.shape, i2.shape))
    # i1, i2 = i1.to('cuda:0'), i2.to('cuda:0')
    print('Device of data item 1: {}; device of data item 2: {}\n'.format(i1.device, i2.device))

Yang akan menghasilkan output sebagai berikut; catatan - tanpa instruksi transfer perangkat yang eksplisit, data dimuat ke CPU:

Batch Index: 0
Shape of data item 1: torch.Size([50]); shape of data item 2: torch.Size([50])
Device of data item 1: cpu; device of data item 2: cpu

Batch Index: 1
Shape of data item 1: torch.Size([50]); shape of data item 2: torch.Size([50])
Device of data item 1: cpu; device of data item 2: cpu

Solusi yang mungkin ada di repo GitHub PyTorch ini. Masalah(masih terbuka pada saat pertanyaan ini diposting), namun, saya tidak dapat membuatnya berfungsi ketika pemuat data harus mengembalikan beberapa item data!


person anurag    schedule 28.01.2021    source sumber
comment
Tidak ada yang pernah mengalami masalah ini?   -  person anurag    schedule 01.02.2021
comment
Parameter collate-fn tidak berguna ketika kumpulan data mengembalikan sejumlah nilai di setiap kumpulan. Jadi masih mencari solusi yang lebih baik!   -  person anurag    schedule 02.02.2021


Jawaban (1)


Anda dapat memodifikasi collate_fn untuk menangani beberapa item sekaligus:

from torch.utils.data.dataloader import default_collate

device = torch.device('cuda:0')  # or whatever device/cpu you like

# the new collate function is quite generic
loader = DataLoader(demo, batch_size=50, shuffle=True, 
                    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))

Perhatikan bahwa jika Anda ingin memiliki banyak pekerja untuk pemuat data, Anda harus menambahkan

torch.multiprocessing.set_start_method('spawn')

setelah if __name__ == '__main__' Anda (lihat masalah ini).

Karena itu, sepertinya menggunakan pin_memory=True di DataLoader Anda akan jauh lebih efisien. Sudahkah Anda mencoba opsi ini?
Lihat penyematan memori untuk mengetahui informasi lebih lanjut.


Pembaruan (8 Februari 2021)
Postingan ini membuat saya melihat waktu data-ke-model yang saya habiskan selama pelatihan. Saya membandingkan tiga alternatif:

  1. DataLoader bekerja pada CPU dan hanya setelah batch diambil, data dipindahkan ke GPU.
  2. Sama seperti (1) tetapi dengan pin_memory=True di DataLoader.
  3. Metode yang diusulkan menggunakan collate_fn untuk memindahkan data ke GPU.

Dari eksperimen saya yang terbatas, sepertinya opsi kedua memiliki performa terbaik (tetapi tidak dengan margin yang besar).
Opsi ketiga memerlukan kerepotan tentang start_method proses pemuat data, dan tampaknya menimbulkan overhead di awal setiap zaman.

person Shai    schedule 02.02.2021
comment
instruksi spawn tidak berfungsi untuk saya, apakah ini khusus untuk versi tertentu? - person anurag; 08.02.2021
comment
@anurag Saya tidak tahu. - person Shai; 08.02.2021