Я только что обнаружил numba и узнал, что оптимальная производительность требует добавления @njit
к большинству функций, так что numba редко выходит из режима LLVM.
У меня все еще есть несколько дорогостоящих/поисковых функций, которые могли бы выиграть от мемоизации, но пока ни одна из моих попыток не нашла работоспособного решения, которое компилируется без ошибок.
- Использование общих функций декоратора до
@njit
приводит к тому, что numba не может делать вывод типа. - Использование декораторов после того, как
@njit
не удалось скомпилировать декоратор - Numba не любит использование
global
переменных, даже при использованииnumba.typed.Dict
- Numba не любит использовать замыкания для хранения изменяемого состояния.
- Удаление
@njit
также вызывает ошибки типа при вызове из других функций@njit
.
Как правильно добавить запоминание функций при работе внутри numba?
import functools
import time
import fastcache
import numba
import numpy as np
import toolz
from numba import njit
from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize
# @fastcache.clru_cache(None) # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None) # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None) # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None) # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
return bitmasks
# @fastcache.clru_cache(None) # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None) # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @toolz.memoize # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
return bitmasks
# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
# Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
global _expensive_cache
if _expensive_cache is None:
bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
_expensive_cache = bitmasks
return _expensive_cache
# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
# is not supported as globals are considered compile-time constants and there is no known way to compile
# a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
key_type = numba.types.string,
value_type = numba.uint64[:]
)
@njit
def expensive_cache():
global cache
if "expensive" not in cache:
bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
cache["expensive"] = bitmasks
return cache["expensive"]
# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
cache = []
def wrapper(bitmasks):
if len(cache) is None:
bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
cache.append(bitmasks)
return cache[0]
return wrapper
expensive_wrapped = _expensive_wrapped()
@njit
def loop(count):
for n in range(count):
expensive()
# expensive_nojit()
# expensive_cache()
# expensive_global)
# expensive_wrapped()
def main():
time_start = time.perf_counter()
count = 10000
loop(count)
time_taken = time.perf_counter() - time_start
print(f'{count} loops in {time_taken:.4f}s')
loop(1) # precache numba
main()
# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s