Memoisasi yang Kompatibel dengan Numba

Saya baru saja menemukan numba, dan mengetahui bahwa kinerja optimal memerlukan penambahan @njit ke sebagian besar fungsi, sehingga numba jarang keluar dari mode LLVM.

Saya masih memiliki beberapa fungsi mahal/pencarian yang dapat memanfaatkan memoisasi, tetapi sejauh ini tidak ada upaya saya yang menemukan solusi yang bisa diterapkan yang dapat dikompilasi tanpa kesalahan.

  • Menggunakan fungsi dekorator umum, sebelum @njit mengakibatkan numba tidak dapat melakukan inferensi tipe.
  • Menggunakan dekorator setelah @njit gagal mengkompilasi dekorator
  • Numba tidak menyukai penggunaan variabel global, bahkan saat menggunakan numba.typed.Dict
  • Numba tidak suka menggunakan penutupan untuk menyimpan keadaan yang bisa berubah
  • Menghapus @njit juga menyebabkan kesalahan ketik saat dipanggil dari fungsi @njit lainnya

Apa cara yang benar untuk menambahkan memoisasi ke fungsi saat bekerja di dalam numba?

import functools
import time

import fastcache
import numba
import numpy as np
import toolz
from numba import njit

from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None)  # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None)   # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize               # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @toolz.memoize               # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks


# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
#      Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache


# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
#      is not supported as globals are considered compile-time constants and there is no known way to compile
#      a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type   = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]


# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper(bitmasks):
        if len(cache) is None:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()

@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global)
        # expensive_wrapped()

def main():
    time_start = time.perf_counter()

    count = 10000
    loop(count)

    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')


loop(1)  # precache numba
main()

# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s

person James McGuigan    schedule 27.07.2020    source sumber
comment
Apakah ini contoh dunia nyata Anda? Menghitung fungsi yang mahal dengan beberapa pengoptimalan bila diperlukan sebenarnya lebih cepat daripada pencarian memori (menghilangkan pemahaman daftar dan meneruskan blok memori yang telah dialokasikan sebelumnya ke dalam fungsi). Bekerja dengan dunia nyata merupakan hal yang rumit di Numba dan membutuhkan solusi. Misalnya, Anda dapat meneruskan alamat memori (sebagai int64) ke dalam fungsi yang dikompilasi dan bekerja dengannya. (diperlukan callable tingkat rendah yang mentransmisikan alamat memori int64 ke sebuah pointer). misalnya. stackoverflow.com/a/61550054/4045774   -  person max9111    schedule 28.07.2020
comment
Terima kasih atas tautan dan ide menggunakan pointer. Kode dunia nyata saya sedikit lebih rumit (Kompetisi Kaggle Connect 4), jadi saya mencoba menyederhanakan numba menjadi testcase POC minimalis untuk pertanyaan ini.   -  person James McGuigan    schedule 28.07.2020