Temukan indeks titik terdekat secara efisien pada kisi 2D non-persegi panjang

Saya memiliki kisi-kisi bujur/latar belakang yang tidak beraturan (bukan persegi panjang) dan sekumpulan titik dalam koordinat bujur/latar belakang, yang harus sesuai dengan titik-titik pada kisi (walaupun mungkin sedikit melenceng karena alasan numerik). Sekarang saya memerlukan indeks titik bujur/latar belakang yang sesuai.

Saya telah menulis sebuah fungsi yang melakukan ini, tetapi ini SANGAT lambat.

def find_indices(lon,lat,x,y):
    lonlat = np.dstack([lon,lat])
    delta = np.abs(lonlat-[x,y])
    ij_1d = np.linalg.norm(delta,axis=2).argmin()
    i,j = np.unravel_index(ij_1d,lon.shape)
    return i,j

ind = [find_indices(lon,lat,p*) for p in points]

Saya cukup yakin ada solusi yang lebih baik (dan lebih cepat) di numpy/scipy. Saya sudah cukup banyak mencari di Google, tetapi sejauh ini jawabannya masih belum saya ketahui.

Adakah saran bagaimana cara menemukan indeks titik (terdekat) yang sesuai secara efisien?

PS: Pertanyaan ini muncul dari pertanyaan lain (klik).

Sunting: Solusi

Berdasarkan jawaban @ Cong Ma, saya menemukan solusi berikut:

def find_indices(points,lon,lat,tree=None):
    if tree is None:
        lon,lat = lon.T,lat.T
        lonlat = np.column_stack((lon.ravel(),lat.ravel()))
        tree = sp.spatial.cKDTree(lonlat)
    dist,idx = tree.query(points,k=1)
    ind = np.column_stack(np.unravel_index(idx,lon.shape))
    return [(i,j) for i,j in ind]

Untuk menempatkan solusi ini dan juga jawaban dari Divakar ke dalam perspektif, berikut adalah beberapa pengaturan waktu dari fungsi yang saya gunakan find_indices (dan di mana hambatannya dalam hal kecepatan) (lihat tautan di atas):

spatial_contour_frequency/pil0                :   331.9553
spatial_contour_frequency/pil1                :   104.5771
spatial_contour_frequency/pil2                :     2.3629
spatial_contour_frequency/pil3                :     0.3287

pil0 adalah pendekatan awal saya, pil1 Divakar, dan pil2/pil3 solusi akhir di atas, di mana pohon dibuat dengan cepat di pil2 (yaitu untuk setiap iterasi loop di mana find_indices dipanggil) dan hanya sekali dalam pil3 ( lihat thread lain untuk detailnya). Meskipun penyempurnaan Divakar pada pendekatan awal saya memberi saya peningkatan kecepatan 3x, cKDTree membawa ini ke tingkat yang benar-benar baru dengan peningkatan kecepatan 50x lainnya! Dan memindahkan kreasi pohon dari fungsinya akan membuat segalanya menjadi lebih cepat.


person flotzilla    schedule 02.10.2015    source sumber
comment
Dalam skrip Anda, Anda membuat pohon baru dengan setiap panggilan ke find_indices. Jika grid Anda diperbaiki di seluruh panggilan, Anda membuang-buang waktu untuk membuat pohon yang sama berulang kali. Sebenarnya konstruksi pohon ini adalah panggilan paling lambat dalam fungsi ini.   -  person Cong Ma    schedule 05.10.2015
comment
Ya, saya perhatikan, itulah yang sedang saya kerjakan saat ini. ;) Saya akan memperbarui jawabannya. Terima kasih atas komentarnya.   -  person flotzilla    schedule 05.10.2015


Jawaban (2)



Berikut adalah pendekatan vektor umum menggunakan scipy.spatial.distance.cdist -

import scipy

# Stack lon and lat arrays as columns to form a Nx2 array, where is N is grid**2
lonlat = np.column_stack((lon.ravel(),lat.ravel()))

# Get the distances and get the argmin across the entire N length
idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)

# Get the indices corresponding to grid's shape as the final output
ind = np.column_stack((np.unravel_index(idx,lon.shape))).tolist()

Contoh dijalankan -

In [161]: lon
Out[161]: 
array([[-11.   ,  -7.82 ,  -4.52 ,  -1.18 ,   2.19 ],
       [-12.   ,  -8.65 ,  -5.21 ,  -1.71 ,   1.81 ],
       [-13.   ,  -9.53 ,  -5.94 ,  -2.29 ,   1.41 ],
       [-14.1  ,  -0.04 ,  -6.74 ,  -2.91 ,   0.976]])

In [162]: lat
Out[162]: 
array([[-11.2  ,  -7.82 ,  -4.51 ,  -1.18 ,   2.19 ],
       [-12.   ,  -8.63 ,  -5.27 ,  -1.71 ,   1.81 ],
       [-13.2  ,  -9.52 ,  -5.96 ,  -2.29 ,   1.41 ],
       [-14.3  ,  -0.06 ,  -6.75 ,  -2.91 ,   0.973]])

In [163]: lonlat = np.column_stack((lon.ravel(),lat.ravel()))

In [164]: idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)

In [165]: np.column_stack((np.unravel_index(idx,lon.shape))).tolist()
Out[165]: [[0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [3, 3]]

Tes waktu proses -

Tentukan fungsi:

def find_indices(lon,lat,x,y):
    lonlat = np.dstack([lon,lat])
    delta = np.abs(lonlat-[x,y])
    ij_1d = np.linalg.norm(delta,axis=2).argmin()
    i,j = np.unravel_index(ij_1d,lon.shape)
    return i,j

def loopy_app(lon,lat,pts):
    return [find_indices(lon,lat,pts[i,0],pts[i,1]) for i in range(pts.shape[0])]

def vectorized_app(lon,lat,points):
    lonlat = np.column_stack((lon.ravel(),lat.ravel()))
    idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)
    return np.column_stack((np.unravel_index(idx,lon.shape))).tolist()

Waktu:

In [179]: lon = np.random.rand(100,100)

In [180]: lat = np.random.rand(100,100)

In [181]: points = np.random.rand(50,2)

In [182]: %timeit loopy_app(lon,lat,points)
10 loops, best of 3: 47 ms per loop

In [183]: %timeit vectorized_app(lon,lat,points)
10 loops, best of 3: 16.6 ms per loop

Untuk meningkatkan kinerja, np.concatenate dapat digunakan dalam tempat np.column_stack.

person Divakar    schedule 03.10.2015