ค้นหาดัชนีของจุดที่ใกล้ที่สุดบนตาราง 2D ที่ไม่ใช่สี่เหลี่ยมได้อย่างมีประสิทธิภาพ

ฉันมีตาราง lon/lat ที่ผิดปกติ (ไม่ใช่สี่เหลี่ยม) และมีจุดจำนวนมากในพิกัด lon/lat ซึ่งควรสอดคล้องกับจุดบนตาราง (แม้ว่ามันอาจจะผิดไปเล็กน้อยด้วยเหตุผลเชิงตัวเลข) ตอนนี้ฉันต้องการดัชนีของจุด lon/lat ที่สอดคล้องกัน

ฉันได้เขียนฟังก์ชันที่ทำสิ่งนี้ แต่มันช้ามาก

def find_indices(lon,lat,x,y):
    lonlat = np.dstack([lon,lat])
    delta = np.abs(lonlat-[x,y])
    ij_1d = np.linalg.norm(delta,axis=2).argmin()
    i,j = np.unravel_index(ij_1d,lon.shape)
    return i,j

ind = [find_indices(lon,lat,p*) for p in points]

ฉันค่อนข้างแน่ใจว่ามีวิธีแก้ปัญหาที่ดีกว่า (และเร็วกว่า) ใน numpy/scipy ฉันเคยไป Google ค่อนข้างมากแล้ว แต่คำตอบก็ยังหลบเลี่ยงฉันอยู่

มีข้อเสนอแนะใดเกี่ยวกับวิธีการค้นหาดัชนีของจุดที่เกี่ยวข้อง (ใกล้ที่สุด) ได้อย่างมีประสิทธิภาพ

ป.ล. คำถามนี้มาจากคำถามอื่น (คลิก)

แก้ไข: วิธีแก้ปัญหา

จากคำตอบของ @Cong Ma ฉันพบวิธีแก้ไขต่อไปนี้:

def find_indices(points,lon,lat,tree=None):
    if tree is None:
        lon,lat = lon.T,lat.T
        lonlat = np.column_stack((lon.ravel(),lat.ravel()))
        tree = sp.spatial.cKDTree(lonlat)
    dist,idx = tree.query(points,k=1)
    ind = np.column_stack(np.unravel_index(idx,lon.shape))
    return [(i,j) for i,j in ind]

เพื่อนำวิธีแก้ปัญหานี้และคำตอบจาก Divakar ไปสู่มุมมอง นี่คือการกำหนดเวลาบางส่วนของฟังก์ชันที่ฉันใช้ find_indices (และที่ซึ่งมันเป็นคอขวดในแง่ของความเร็ว) (ดูลิงก์ด้านบน):

spatial_contour_frequency/pil0                :   331.9553
spatial_contour_frequency/pil1                :   104.5771
spatial_contour_frequency/pil2                :     2.3629
spatial_contour_frequency/pil3                :     0.3287

pil0 เป็นแนวทางเริ่มต้นของฉัน pil1 Divakar's และ pil2/pil3 วิธีแก้ปัญหาสุดท้ายด้านบน โดยที่ทรีถูกสร้างขึ้นทันทีใน pil2 (เช่น สำหรับการวนซ้ำทุกครั้งของลูปที่มีการเรียก find_indices) และเพียงครั้งเดียวใน pil3 ( ดูรายละเอียดหัวข้ออื่น) แม้ว่าการปรับปรุงแนวทางเริ่มต้นของ Divakar จะทำให้ฉันมีความเร็วเพิ่มขึ้น 3 เท่า แต่ cKDTree ก็ยกระดับสิ่งนี้ไปอีกระดับด้วยการเร่งความเร็วอีก 50 เท่า! และการย้ายการสร้างทรีออกจากฟังก์ชันจะทำให้สิ่งต่างๆ เร็วขึ้นอีก


person flotzilla    schedule 02.10.2015    source แหล่งที่มา
comment
ในสคริปต์ของคุณ คุณกำลังสร้างแผนผังใหม่โดยมีการเรียก find_indices แต่ละครั้ง หากกริดของคุณได้รับการแก้ไขในการโทร คุณจะเสียเวลาในการสร้างต้นไม้ต้นเดียวกันซ้ำแล้วซ้ำอีก จริงๆ แล้ว การสร้างแผนผังนี้เป็นการเรียกที่ช้าที่สุดในฟังก์ชันนี้   -  person Cong Ma    schedule 05.10.2015
comment
ใช่ ฉันสังเกตเห็นว่านี่คือสิ่งที่ฉันกำลังทำอยู่ในขณะนี้ ;) ฉันจะอัปเดตคำตอบตามนั้น ขอบคุณสำหรับข้อสังเกต   -  person flotzilla    schedule 05.10.2015


คำตอบ (2)


หากคะแนนได้รับการแปลอย่างเพียงพอ คุณอาจลองใช้ scipy.spatial's cKDTree การใช้งานโดยตรง ตามที่อธิบายไว้ด้วยตัวเอง ในโพสต์อื่น โพสต์นั้นเกี่ยวกับการแก้ไข แต่คุณสามารถเพิกเฉยได้และใช้ส่วนแบบสอบถามเท่านั้น

tl; รุ่น dr:

อ่านเอกสารประกอบของ scipy.sptial.cKDTree< /ก>. สร้างทรีโดยส่งวัตถุ numpy ndarray ที่มีรูปร่าง (n, m) ไปยังเครื่องมือเริ่มต้น และทรีจะถูกสร้างขึ้นจากพิกัด n m มิติ

tree = scipy.spatial.cKDTree(array_of_coordinates)

หลังจากนั้น ใช้ tree.query() เพื่อเรียกเพื่อนบ้านที่ใกล้ที่สุดอันดับที่ k (อาจมีการประมาณและการขนาน โปรดดูเอกสาร) หรือใช้ tree.query_ball_point() เพื่อค้นหาเพื่อนบ้านทั้งหมดที่อยู่ในระยะที่ยอมรับได้

หากจุดต่างๆ ไม่ได้รับการแปลอย่างเหมาะสม และโทโพโลยีโค้งทรงกลม/ไม่สำคัญเกิดขึ้น คุณสามารถลองแยกท่อร่วมออกเป็นหลายส่วน โดยแต่ละส่วนมีขนาดเล็กพอที่จะพิจารณาว่าเป็นจุดท้องถิ่น

person Cong Ma    schedule 02.10.2015

ต่อไปนี้เป็นวิธีการแบบเวกเตอร์ทั่วไปโดยใช้ scipy.spatial.distance.cdist - -

import scipy

# Stack lon and lat arrays as columns to form a Nx2 array, where is N is grid**2
lonlat = np.column_stack((lon.ravel(),lat.ravel()))

# Get the distances and get the argmin across the entire N length
idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)

# Get the indices corresponding to grid's shape as the final output
ind = np.column_stack((np.unravel_index(idx,lon.shape))).tolist()

ตัวอย่างการวิ่ง -

In [161]: lon
Out[161]: 
array([[-11.   ,  -7.82 ,  -4.52 ,  -1.18 ,   2.19 ],
       [-12.   ,  -8.65 ,  -5.21 ,  -1.71 ,   1.81 ],
       [-13.   ,  -9.53 ,  -5.94 ,  -2.29 ,   1.41 ],
       [-14.1  ,  -0.04 ,  -6.74 ,  -2.91 ,   0.976]])

In [162]: lat
Out[162]: 
array([[-11.2  ,  -7.82 ,  -4.51 ,  -1.18 ,   2.19 ],
       [-12.   ,  -8.63 ,  -5.27 ,  -1.71 ,   1.81 ],
       [-13.2  ,  -9.52 ,  -5.96 ,  -2.29 ,   1.41 ],
       [-14.3  ,  -0.06 ,  -6.75 ,  -2.91 ,   0.973]])

In [163]: lonlat = np.column_stack((lon.ravel(),lat.ravel()))

In [164]: idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)

In [165]: np.column_stack((np.unravel_index(idx,lon.shape))).tolist()
Out[165]: [[0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [0, 4], [3, 3]]

การทดสอบรันไทม์ -

กำหนดฟังก์ชัน:

def find_indices(lon,lat,x,y):
    lonlat = np.dstack([lon,lat])
    delta = np.abs(lonlat-[x,y])
    ij_1d = np.linalg.norm(delta,axis=2).argmin()
    i,j = np.unravel_index(ij_1d,lon.shape)
    return i,j

def loopy_app(lon,lat,pts):
    return [find_indices(lon,lat,pts[i,0],pts[i,1]) for i in range(pts.shape[0])]

def vectorized_app(lon,lat,points):
    lonlat = np.column_stack((lon.ravel(),lat.ravel()))
    idx = scipy.spatial.distance.cdist(lonlat,points).argmin(0)
    return np.column_stack((np.unravel_index(idx,lon.shape))).tolist()

การกำหนดเวลา:

In [179]: lon = np.random.rand(100,100)

In [180]: lat = np.random.rand(100,100)

In [181]: points = np.random.rand(50,2)

In [182]: %timeit loopy_app(lon,lat,points)
10 loops, best of 3: 47 ms per loop

In [183]: %timeit vectorized_app(lon,lat,points)
10 loops, best of 3: 16.6 ms per loop

เพื่อเพิ่มประสิทธิภาพการทำงานให้มากขึ้น np.concatenate สามารถใช้ใน สถานที่ของ np.column_stack

person Divakar    schedule 03.10.2015