ความทรงจำที่เข้ากันได้กับ Numba

ฉันเพิ่งค้นพบ numba และได้เรียนรู้ว่าประสิทธิภาพสูงสุดจำเป็นต้องเพิ่ม @njit ให้กับฟังก์ชันส่วนใหญ่ เพื่อให้ numba แทบจะไม่ออกจากโหมด LLVM

ฉันยังมีฟังก์ชันการค้นหา/ราคาแพงอยู่บ้างที่อาจได้รับประโยชน์จากการท่องจำ แต่จนถึงขณะนี้ยังไม่มีความพยายามใดของฉันที่พบวิธีแก้ปัญหาที่ใช้งานได้ซึ่งคอมไพล์โดยไม่มีข้อผิดพลาด

  • การใช้ฟังก์ชันมัณฑนากรทั่วไป ก่อน @njit ส่งผลให้ numba ไม่สามารถทำการอนุมานประเภทได้
  • การใช้มัณฑนากรหลังจาก @njit ไม่สามารถคอมไพล์มัณฑนากรได้
  • Numba ไม่ชอบการใช้ตัวแปร global แม้ว่าจะใช้ numba.typed.Dict
  • Numba ไม่ชอบการใช้การปิดเพื่อจัดเก็บสถานะที่ไม่แน่นอน
  • การลบ @njit ยังทำให้เกิดข้อผิดพลาดประเภทเมื่อเรียกจากฟังก์ชัน @njit อื่นๆ

วิธีที่ถูกต้องในการเพิ่มการจดจำให้กับฟังก์ชั่นเมื่อทำงานภายใน numba คืออะไร?

import functools
import time

import fastcache
import numba
import numpy as np
import toolz
from numba import njit

from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None)  # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None)   # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize               # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks



# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @toolz.memoize               # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks


# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
#      Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache


# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
#      is not supported as globals are considered compile-time constants and there is no known way to compile
#      a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type   = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]


# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper(bitmasks):
        if len(cache) is None:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()

@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global)
        # expensive_wrapped()

def main():
    time_start = time.perf_counter()

    count = 10000
    loop(count)

    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')


loop(1)  # precache numba
main()

# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s

person James McGuigan    schedule 27.07.2020    source แหล่งที่มา
comment
นี่คือตัวอย่างในโลกแห่งความเป็นจริงของคุณใช่ไหม? การคำนวณฟังก์ชันที่มีราคาแพงพร้อมการปรับให้เหมาะสมบางอย่างเมื่อจำเป็นนั้นเร็วกว่าการค้นหาหน่วยความจำ (กำจัดความเข้าใจในรายการและส่งบล็อกหน่วยความจำที่จัดสรรไว้ล่วงหน้าไปยังฟังก์ชัน) การทำงานกับ Real Globals นั้นซับซ้อนใน Numba และจำเป็นต้องได้รับการแก้ไข เช่น คุณสามารถส่งที่อยู่หน่วยความจำ (เป็น int64) ไปยังฟังก์ชันที่คอมไพล์แล้วใช้งานได้ (จำเป็นที่สามารถเรียกได้ระดับต่ำซึ่งจะแปลงที่อยู่หน่วยความจำ int64 ไปยังตัวชี้) เช่น. stackoverflow.com/a/61550054/4045774   -  person max9111    schedule 28.07.2020
comment
ขอบคุณสำหรับลิงค์และแนวคิดในการใช้พอยน์เตอร์ รหัสโลกแห่งความเป็นจริงของฉันซับซ้อนกว่าเล็กน้อย (การแข่งขัน Kaggle Connect 4) ดังนั้นฉันจึงพยายามลดความซับซ้อนของ numba ลงไปเป็นกรณีทดสอบ POC ที่เรียบง่ายสำหรับคำถามนี้   -  person James McGuigan    schedule 28.07.2020