K najbliższych sąsiadów

Prosty algorytm klasyfikujący nowy przykład na podstawie najczęstszej klasy wśród jego K najbliższych sąsiadów w danych. Nie wymaga osobnego etapu treningu.

K najbliższych sąsiadów (K-Nearest Neighbors, KNN) to jeden z najprostszych algorytmów uczenia maszynowego: żeby zaklasyfikować nowy przykład, patrzysz na K najbliższych mu punktów w danych treningowych i przypisujesz mu klasę, która wśród nich występuje najczęściej (zwykłe głosowanie większościowe). Cały „model” to po prostu zapamiętany zbiór treningowy — dlatego mówi się o KNN jako o lazy learning: nie ma osobnego, kosztownego etapu uczenia, cała praca dzieje się dopiero w momencie predykcji.

Jak to działa

Mechanika jest brutalnie prosta. Dla nowego punktu liczysz jego odległość do każdego punktu w zbiorze (najczęściej metryką euklidesową, czasem Manhattan albo kosinusową), sortujesz, bierzesz K najbliższych i sprawdzasz, jaka klasa dominuje. Przy K=1 kopiujesz etykietę najbliższego sąsiada, przy większym K uśredniasz decyzję po sąsiedztwie. KNN robi też regresję — wtedy zamiast głosowania bierzesz średnią (lub ważoną średnią) wartości sąsiadów.

Nie ma tu żadnych „wag”, które algorytm by trenował — wynik zależy wyłącznie od danych, wybranej metryki i wartości K. To zaleta (zero założeń o kształcie granicy decyzyjnej) i przekleństwo naraz, bo każda predykcja wymaga przejrzenia całego zbioru.

Przykład z praktyki

W scikit-learn robisz to dosłownie w trzech linijkach:

  1. from sklearn.neighbors import KNeighborsClassifier
  2. model = KNeighborsClassifier(n_neighbors=5)
  3. model.fit(X_train, y_train) i potem model.predict(X_test)

Klasyczny scenariusz: masz dane irysów albo zbiór odręcznych cyfr (MNIST) i chcesz szybki baseline, zanim sięgniesz po cięższe modele. KNN świetnie się do tego nadaje — w 5 minut masz punkt odniesienia, z którym porównasz każdy kolejny model. Ten sam algorytm napędza też proste systemy rekomendacji („użytkownicy podobni do ciebie”) i wyszukiwanie podobieństwa.

Częste błędy i mity

  • Brak skalowania cech. Jeśli jedna cecha jest w tysiącach, a druga w ułamkach, odległość zdominuje ta pierwsza. Zawsze rób StandardScaler albo normalizację przed KNN.
  • Parzyste K przy dwóch klasach. Dostaniesz remisy w głosowaniu. Trzymaj się nieparzystych wartości.
  • „KNN się trenuje”. Nie — fit() w zasadzie tylko zapamiętuje dane. Cały koszt to predykcja, która przy dużych zbiorach jest wolna (klątwa wymiarowości też nie pomaga).
  • Mylenie z K-means. KNN to klasyfikacja z nadzorem; K-means to klasteryzacja bez nadzoru. Wspólne jest tylko „K” w nazwie.

Pojęcia powiązane

Warto kojarzyć: metryki odległości (euklidesowa, Manhattan, kosinusowa), normalizacja i standaryzacja cech, klątwa wymiarowości, walidacja krzyżowa (do doboru K), drzewa KD-tree i ball-tree (przyspieszają szukanie sąsiadów) oraz K-means, z którym KNN bywa mylony.