Совместимая с Numba мемоизация

Я только что обнаружил numba и узнал, что оптимальная производительность требует добавления @njit к большинству функций, так что numba редко выходит из режима LLVM.

У меня все еще есть несколько дорогостоящих/поисковых функций, которые могли бы выиграть от мемоизации, но пока ни одна из моих попыток не нашла работоспособного решения, которое компилируется без ошибок.

  • Использование общих функций декоратора до @njit приводит к тому, что numba не может делать вывод типа.
  • Использование декораторов после того, как @njit не удалось скомпилировать декоратор
  • Numba не любит использование global переменных, даже при использовании numba.typed.Dict
  • Numba не любит использовать замыкания для хранения изменяемого состояния.
  • Удаление @njit также вызывает ошибки типа при вызове из других функций @njit.

Как правильно добавить запоминание функций при работе внутри numba?

import functools
import time

import fastcache
import numba
import numpy as np
import toolz
from numba import njit

from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None)  # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None)   # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize               # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @toolz.memoize               # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks


# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
#      Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache


# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
#      is not supported as globals are considered compile-time constants and there is no known way to compile
#      a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type   = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]


# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper(bitmasks):
        if len(cache) is None:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()

@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global)
        # expensive_wrapped()

def main():
    time_start = time.perf_counter()

    count = 10000
    loop(count)

    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')


loop(1)  # precache numba
main()

# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s

person James McGuigan    schedule 27.07.2020    source источник
comment
Это ваш реальный пример? Вычисление дорогостоящей функции с некоторыми оптимизациями, когда это необходимо, на самом деле быстрее, чем поиск в памяти (избавление от понимания списка и передача предварительно выделенного блока памяти в функцию). Работа с реальными глобальными переменными в Numba сложна и требует обходных путей. Вы можете передать, например, адрес памяти (в виде int64) в скомпилированную функцию и работать с ней. (необходим вызов низкого уровня, который приводит адрес памяти int64 к указателю). например. stackoverflow.com/a/61550054/4045774   -  person max9111    schedule 28.07.2020
comment
Спасибо за ссылку и идею использования указателей. Мой реальный код немного сложнее (Kaggle Connect 4 Competition), поэтому я попытался упростить numba до минималистского тестового примера POC для этого вопроса.   -  person James McGuigan    schedule 28.07.2020