Unit Scaling adalah metode pembelajaran mesin presisi rendah baru yang mampu melatih model bahasa di FP16 dan FP8 tanpa kehilangan penskalaan.

Penulis: Charlie Blake, Peneliti AI di Graphcore

"Baca korannya" | Kode | Buku catatan demo PyTorch

Dalam beberapa tahun terakhir, komunitas pembelajaran mendalam telah beralih dari format angka FP32 ke format FP16 dan BFLOAT16. Hal ini menyebabkan pengurangan besar dalam memori, bandwidth, dan kebutuhan komputasi — yang semuanya penting untuk tren model yang semakin besar.

Sekarang, dengan pengembangan perangkat keras yang mendukung FP8 (seperti prosesor Graphcore IPU Bow yang digunakan dalam kartu PCIe C600) penghematan efisiensi presisi rendah lebih lanjut dapat dilakukan. Namun, sejauh ini format yang lebih kecil dan berpresisi rendah ini tidak selalu mudah digunakan dalam praktiknya. Dengan FP8, hal ini mungkin menjadi lebih sulit.

Tantangan yang paling signifikan adalah format yang lebih kecil ini sering kali membatasi pengguna pada rentang nilai yang lebih sempit dan dapat direpresentasikan. Pertanyaan yang muncul adalah: bagaimana kita memastikan bahwa model kita tetap berada dalam rentang format yang lebih kecil? Untuk mengatasi hal ini, Graphcore Research telah mengembangkan metode baru, yang kami beri nama penskalaan unit.

Penskalaan unit adalah teknik desain model yang beroperasi berdasarkan prinsip penskalaan ideal pada inisialisasi; yaitu varian satuan untuk aktivasi, bobot, dan gradien. Hal ini dicapai dengan mempertimbangkan perubahan varians yang ditimbulkan oleh setiap operasi dalam model dan memperkenalkan faktor penskalaan tetap untuk mengatasi hal ini.

Model yang dihasilkan secara otomatis menghasilkan tensor dengan skala yang baik untuk format angka dengan presisi rendah, sehingga penggunaannya menjadi mudah dan meminimalkan kelemahan dari representasi yang sangat efisien ini. Biaya tambahan dan kompleksitas tambahan yang diterapkan sangat minim, tidak seperti pendekatan alternatif pada pelatihan presisi rendah.

Metode kami mencapai hasil terobosan: untuk pertama kalinya, kami telah melatih model BERT Base dan BERT Large secara akurat di FP16 dan bahkan FP8 tanpa penskalaan kerugian. Penskalaan unit dapat dilakukan secara langsung, tanpa memerlukan sapuan atau hyperparameter tambahan untuk pelatihan. Model skala unit kemudian dapat digunakan untuk inferensi tanpa batasan atau modifikasi tambahan.

Bagi praktisi yang peduli dengan efisiensi — dan karenanya ingin berlatih di FP16 dan FP8 — penskalaan unit menawarkan solusi yang mudah. IPU ini sangat cocok untuk kasus penggunaan ini, dengan prosesor Bow IPU dari Graphcore saat ini yang menyediakan komputasi FP16 yang dipercepat, dan perangkat keras IPU generasi berikutnya menambahkan komputasi FP8 yang dipercepat. Pengguna dapat mencoba sendiri penskalaan unit melalui Paperspace notebook yang menyertainya.

Pendekatan yang ada untuk pelatihan FP16/FP8

Pelatihan FP16 dan FP8 memerlukan beberapa bentuk penskalaan untuk menjaga nilai dalam jangkauan. Pendekatan yang ada saat ini adalah sebagai berikut:

Penskalaan kerugian (Statis).

Pengurangan jangkauan sangat menantang untuk gerakan mundur selama latihan, yang sering kali menyebabkan penurunan gradien. Untuk mengatasi hal ini, salah satu pendekatannya adalah dengan mengalikan kerugian dengan hyperparameter skala kerugian untuk meningkatkan ukuran gradien [1]. Karena tidak ada cara yang berprinsip untuk memilih skala kerugian sebelumnya, hyperparameter ini mungkin perlu disapu, seringkali memerlukan beberapa kali proses penuh.

Penskalaan kerugian otomatis

Seseorang dapat menghindari kebutuhan akan penyapuan hyperparameter dengan menyesuaikan skala kerugian secara dinamis berdasarkan luapan gradien run-time (atau histogram) [2]. Hal ini juga dapat mengatasi pergeseran distribusi tensor selama pelatihan. Sayangnya, skema otomatis dapat menambah biaya tambahan dan kompleksitas.

Penskalaan per-tensor

Kelemahan lain dari metode di atas adalah metode ini hanya menyediakan satu skala kerugian global. Salah satu solusi yang diusulkan adalah menskalakan ulang nilai secara lokal berdasarkan statistik tensor [3]. Ini juga merupakan skema otomatis/run-time, sehingga mungkin rumit dan sulit diterapkan secara efisien.

Penskalaan unit juga memperkenalkan faktor penskalaan lokal dalam proses maju dan mundur untuk mengontrol rentang nilai. Namun, kami memilih faktor-faktor ini berdasarkan pemahaman teoretis tentang bagaimana masing-masing operator memengaruhi skala nilai, dibandingkan menggunakan analisis run-time.

Dengan memilih faktor penskalaan yang tepat, setiap operasi akan mempertahankan skala inputnya. Dengan menerapkan hal ini pada semua operasi, hal ini akan menyebarkan skala (unit) awal ke seluruh model, sehingga memberikan penskalaan unit secara global.

Perhatikan bahwa analisis kami didasarkan pada skala nilai pada inisialisasi, sebelum pelatihan dimulai. Meskipun skala berubah selama pelatihan, kami menemukan bahwa penskalaan awal yang baik memberikan ruang yang cukup sehingga penskalaan ulang tidak diperlukan (pekerjaan selanjutnya akan menyelidiki arah ini lebih lanjut, mengevaluasi kemungkinan penskalaan ulang pada interval yang lebih lama saat kami beralih ke model yang lebih besar).

Metode kami lebih sederhana daripada skema penskalaan otomatis, dan satu-satunya overhead tambahan adalah penerapan faktor penskalaan (perkalian skalar, yang dapat digabungkan ke dalam operasi sebelumnya). Untuk BERT Besar, hal ini menyebabkan peningkatan FLOP sebesar 0,2%.

resep

Sebuah model dapat diskalakan satuan dengan menerapkan resep berikut:

  1. Inisialisasi parameter non-bias dengan unit varians.
  2. Hitung faktor penskalaan ideal untuk semua operasi.
  3. Identifikasi non-cut-edge dan batasi operasi yang mengkonsumsinya agar memiliki skala yang sama.
  4. Ganti penambahan dengan penambahan berbobot.

Kami menjelaskan aturan-aturan ini secara lebih rinci di bawah.

Faktor penskalaan yang ideal

Kita dapat menganalisis beberapa operasi secara matematis untuk menentukan pengaruhnya terhadap varians inputnya.

Misalnya, perkalian matriks dasar XW(dengan X adalah (b × m)matriks dan W adalah (m × n)matriks) memiliki varian keluaran σ(X · σ(W· m. Untuk melakukan operasi ini pada skala satuan, kita harus memastikan σ(X= σ(W )² = 1(dengan menskalakan operasi sebelumnya), lalu menambahkan perkalian 1/√m​ ke hasilnya.

Ini menjelaskan umpan ke depan. Proses mundur memperkenalkan dua perkalian matriks baru, dengan faktor penskalaan ideal sebesar 1/√ndan 1/√b. Operasi lain dapat dianalisis dengan cara yang sama, dan jika varian keluaran tidak dapat dianalisis dengan mudah, metode empiris dapat digunakan untuk menemukan faktor penskalaan.

Kami memberikan analisis yang lebih rinci dalam “makalah arxiv” kami, bersama dengan ringkasan operasi umum dan faktor penskalaan idealnya.

Potong pinggirannya

Penerapan langsung faktor penskalaan ideal ini pada lintasan maju dan mundur dapat menghasilkan gradien yang tidak valid. Untuk menghindari hal ini, kami mengharuskan operasi tertentu menggunakan faktor penskalaan bersama.

Secara khusus, kami mengambil grafik komputasi maju dan menemukan semua variabel yang tidak diwakili oleh tepi potong (tepi yang jika dihilangkan, akan membagi grafik menjadi dua grafik kecil yang tidak terhubung ). Berikut ini tampilan lapisan FFN transformator:

Dalam hal ini, kami memiliki keunggulan dalam variabel bobot, input, dan output. Diagram juga menunjukkan operasi gradien yang dihasilkan untuk gerakan mundur matmul kedua (catatan: kami hanya mempertimbangkan sisi potong untuk grafik maju).

Kita membatasi matmul untuk ∇x₃​​ agar menggunakan faktor penskalaan yang sama seperti pada operan maju, karena x₃​​​ bukanlah cut-edge. Namun, karena w₂​adalah yang terdepan, hal ini memungkinkan adanya faktor penskalaan ke belakang. Untuk memilih faktor penskalaan bersama untuk operasi terbatas, kami mengambil rata-rata geometrik dari faktor penskalaan ideal yang dihitung sebelumnya.

Meskipun aturan mutakhir ini mungkin terdengar rumit, dalam praktiknya biasanya aturan ini bermuara pada prosedur sederhana: memberikan gradien bobot pada faktor penskalaannya sendiri, serta lapisan encoder/decoder dalam model.

Penambahan tertimbang

Langkah terakhir dari resep kita adalah mengganti operasi penjumlahan dengan penambahan tertimbang. Penskalaan unit secara desain menghasilkan variabel dengan skala yang sama, artinya jika kita menambahkan dua tensor, keduanya secara efektif memiliki bobot yang sama. Namun, dalam beberapa kasus, khususnya sambungan sisa, kami mungkin memerlukan bobot yang tidak seimbang untuk mencapai kinerja yang baik.

Untuk memperhitungkan hal ini, kami mengganti operasi penjumlahan dengan operasi berbobot (dan skala unit) yang setara. Untuk sambungan sisa, kami menggunakan ini untuk mendapatkan skema yang direkomendasikan berikut:

Penerapan

Kode berikut menunjukkan implementasi lapisan FFN berskala unit di PyTorch. Kami memberikan contoh implementasi lebih lanjut di basis kode dan buku catatan demo kami.

Pertama-tama kita mendefinisikan beberapa primitif penskalaan, yang memungkinkan kita membuat versi operasi dasar yang berskala, seperti scaled_projection:

class ScaledGrad(autograd.Function):
  @staticmethod
  def forward(ctx, X, alpha, beta):
    ctx.save_for_backward(tensor(beta, dtype=X.dtype))
    return alpha * X

  @staticmethod
  def backward(ctx, grad_Y):
    beta, = ctx.saved_tensors
    return beta * grad_Y, None, None

def scaled(X, alpha=1, beta=1):
  """forward: Y = X * alpha, backward: grad_X = grad_Y * beta"""
  return ScaledGrad.apply(X, alpha, beta)

def scaled_projection(X, W):
  (b, _), (m, n) = X.shape, W.shape
  alpha = beta_X = (m * n) ** -(1/4) beta_W = b ** -(1/2)
  X = scaled(X, beta=beta_X)
  W = scaled(W, beta=beta_W)
  return scaled(matmul(X, W), alpha)

Ini kemudian memungkinkan kita membuat lapisan berskala penuh. Di sini kami mendemonstrasikan FFN standar dan skala unit yang setara:

class FFN(nn.Module):
  def __init__(self, d, h):
    super().__init__()
    self.norm = LayerNorm(d)
    sigma = (d * h) ** -(1/4)
    self.W_1 = Parameter(randn(d, h) * sigma)
    self.W_2 = Parameter(randn(h, d) * sigma)

  def forward(self, X):
    Z = self.norm(X)
    Z = matmul(Z, self.W_1) Z = gelu(Z)
    Z = matmul(Z, self.W_2) return X + Z


class ScaledFFN(nn.Module):
  def __init__(self, d, h, tau):
    super().__init__()
    self.norm = ScaledLayerNorm(d)  # Not defined here
    self.W1 = Parameter(randn(d, h))
    self.W2 = Parameter(randn(h, d))
    self.tau = tau

  def forward(self, X):
    a = (1 - self.tau) ** (1/2)
    b = self.tau ** (1/2)
    Z = self.norm(scaled(X, beta=b))
    Z = scaled_projection(Z, self.W1)
    Z = scaled_gelu(Z)  # Not defined here
    Z = scaled_projection(Z, self.W2)
    return X * a + scaled(Z, b)  # fixed(𝜏) weighted add

Hasil

Hasil eksperimen kami menunjukkan bahwa penskalaan unit efektif di berbagai model, dan berfungsi dengan baik, tanpa memerlukan penyesuaian hyperparameter tambahan.

Eksperimen skala kecil

Rangkaian eksperimen pertama kami memvalidasi penerapan luas penskalaan unit di berbagai arsitektur model. Kami melatih berbagai macam model bahasa tingkat karakter yang lebih kecil dengan dan tanpa penskalaan unit, di FP32 dan FP16, dan membandingkan hasilnya. Konfigurasi ini merupakan penyisiran yang dilakukan pada tahun 2092:

Hasil kami menunjukkan hal berikut: pertama, diperlukan beberapa bentuk penskalaan (kerugian atau unit) saat menggunakan FP16. Hal ini disebabkan oleh penurunan gradien, karena penskalaan kerugian dengan faktor 2048 menyelesaikan masalah tersebut. Kedua, penskalaan unit, meskipun mengubah perilaku pelatihan model lebih dari sekadar numerik, mencocokkan atau bahkan sedikit meningkatkan kinerja dasar di hampir semua kasus. Terakhir, tidak diperlukan penyetelan saat mengalihkan penskalaan unit dari FP32 ke FP16.

Eksperimen skala besar

Rangkaian eksperimen kedua kami memvalidasi efektivitas penskalaan unit pada model tingkat produksi yang lebih besar dan realistis, BERT [4]. Kami menerapkan penyesuaian pada model skala unit kami untuk menyelaraskannya dengan implementasi BERT standar, dan kemudian melatihnya pada teks dari artikel Wikipedia bahasa Inggris.

Hasil kami pada tugas evaluasi SQuAD v1.0 dan SQuAD v2.0 adalah sebagai berikut:

Penskalaan unit mampu mencapai kinerja yang sama dengan model standar (garis dasar), dan meskipun garis dasar memerlukan skala kerugian yang luas, penskalaan unit berfungsi dengan baik di semua kasus. Model dasar dan skala unit tidak sepenuhnya setara, namun penyimpangan dalam kinerja hilirnya kecil (BERT Base skala unit sedikit di bawah garis dasar, dan BERT Besar sedikit di atas).

Implementasi FP8 kami didasarkan pada format yang “baru-baru ini diusulkan” untuk standardisasi oleh Graphcore, AMD dan Qualcomm. Penelitian Graphcore sebelumnya menunjukkan pelatihan BERT skala kerugian di FP8 tanpa degradasi [5], dan sekarang kami menunjukkan bahwa hal yang sama dapat dicapai dengan penskalaan unit.

Tidak diperlukan teknik tambahan untuk membuat FP8 berfungsi di atas FP16. Kami cukup menghitung input matmul kami ke dalam FP8 dan dapat berlatih secara akurat (dengan bobot dan aktivasi di varian FP8 E4, dan gradien di E5). Hasil ini menunjukkan pertama kalinya BERT Base atau BERT Large dilatih di FP16 atau FP8 tanpa memerlukan penskalaan kerugian.

Masa depan pelatihan presisi rendah

Seiring dengan berkembangnya adopsi perangkat keras dengan dukungan FP8 dalam komunitas AI, pentingnya pendekatan penskalaan model yang efektif, lugas, dan berprinsip juga akan meningkat. Penskalaan unit memenuhi semua kriteria ini. Hal ini juga berlaku di berbagai model dan pengoptimal, dengan overhead komputasi yang minimal.

Model besar generasi berikutnya kemungkinan besar akan menggunakan format presisi rendah secara ekstensif, dan karenanya mungkin memerlukan pendekatan seperti penskalaan unit. Kami berharap metode kami dapat berguna untuk aplikasi ini, dan juga memberikan landasan yang kuat untuk penelitian penskalaan di masa depan. Manfaat efisiensi dari pelatihan dengan presisi rendah sangat besar, dan penskalaan unit menunjukkan bahwa pelatihan tersebut tidak memerlukan biaya.

"Baca korannya" | Kode | Buku catatan demo PyTorch

Referensi

[1] P. Micikevicius dkk., Pelatihan presisi campuran (2018). Konferensi Internasional ke-6 tentang Representasi Pembelajaran

[2] O. Kuchaiev et al., Pelatihan presisi campuran untuk nlp dan pengenalan ucapan dengan openseq2seq (2018), arXiv preprint arXiv:1805.10387

[3] P. Micikevicius dkk., Format FP8 untuk pembelajaran mendalam (2022). arXiv pracetak arXiv:2209.05433

[4] J. Devlin dkk., BERT: Pra-pelatihan transformator dua arah yang mendalam untuk pemahaman bahasa (2019). NAACL-HLT

[5] B. Noune dkk., format numerik 8-bit untuk jaringan saraf dalam (2019). arXiv pracetak arXiv:2206.02915