Com utilitzar el mètode 'torch.argmax()' a PyTorch?

Com Utilitzar El Metode Torch Argmax A Pytorch



A PyTorch, el ' torch.argmax() El mètode ” és una funció integrada que retorna índexs dels valors màxims d'un tensor particular en una dimensió determinada. Els usuaris utilitzen aquesta funció quan treballen amb tensors i volen trobar l'índex del valor màxim al llarg de la dimensió donada d'un tensor. A més, aquest mètode també pot ser útil per a la classificació on els usuaris volen saber quina classe té la probabilitat més alta.

Aquest bloc exemplificarà el mètode per utilitzar el mètode 'torch.argmax()' a PyTorch.

Com utilitzar el mètode 'torch.argmax()' a PyTorch?

El mètode 'torch.argmax()' pren qualsevol tensor 1D o 2D com a entrada i retorna un tensor que conté els índexs/índexs dels valors màxims al llarg de la dimensió donada.







La sintaxi del mètode 'torch.argmax()' es mostra a continuació:



torxa. argmax ( < tensor_entrada > )

Per utilitzar aquest mètode a PyTorch, seguiu els exemples següents per entendre-ho millor:



Exemple 1: utilitzeu el mètode 'torch.argmax()' amb tensor 1D

En el primer exemple, crearem un tensor 1D i utilitzarem el mètode 'torch.argmax()' amb ell. Seguim el següent procediment pas a pas:





Pas 1: importa la biblioteca PyTorch

Primer, importeu el ' torxa ” per utilitzar el mètode “torch.argmax()”:

importar torxa

Pas 2: creeu un tensor 1D

A continuació, creeu un tensor 1D i imprimiu els seus elements. Aquí, estem creant el següent ' Desenes 1 ' tensor d'una llista utilitzant el ' torch.tensor() ” funció:



Desenes 1 = torxa. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

imprimir ( Desenes 1 )

Això ha creat un tensor 1D com es veu a continuació:

Pas 3: Trobeu índexs de valor màxim

Ara, utilitzeu el ' torch.argmax() ” funció per trobar l'índex/índexs del valor màxim a la “ Desenes 1 ” tensor:

T1_ind = torxa. argmax ( Desenes 1 )

Pas 4: imprimiu l'índex del valor màxim

Finalment, mostra l'índex del valor màxim al tensor d'entrada:

imprimir ( 'Índexs:' , T1_ind )

La sortida següent mostra l'índex del valor màxim en el ' Desenes 1 ' tensor, és a dir, 4. Significa que el valor més alt del tensor es troba al quart índex que és ' 9 ”:

Exemple 2: utilitzeu el mètode 'torch.argmax()' amb tensor 2D

En el segon exemple, crearem un tensor 2D i utilitzarem el mètode 'torch.argmax()' amb ell. Seguim els passos indicats:

Pas 1: importa la biblioteca PyTorch

Primer, importeu el ' torxa ” per utilitzar el mètode “torch.argmax()”:

importar torxa

Pas 2: creeu un tensor 2D

A continuació, utilitzeu el ' torch.tensor() ” per crear un tensor 2D i imprimir els seus elements. Aquí, estem creant el següent ' Desenes 2 “Tensor 2D:

Desenes 2 = torxa. tensor ( [ [ 4 , 1 , - 7 ] , [ 15 , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

imprimir ( Desenes 2 )

Això ha creat un tensor 2D com es veu a continuació:

Pas 3: Trobeu índexs de valor màxim

Ara, cerqueu l'índex del valor màxim a la ' Desenes 2 ' tensor utilitzant el ' torch.argmax() ” funció:

T2_ind = torxa. argmax ( Desenes 2 )

Pas 4: imprimiu l'índex del valor màxim

Finalment, visualitzeu l'índex del valor màxim al tensor d'entrada:

imprimir ( 'Índexs:' , T2_ind )

Segons la sortida següent, l'índex del valor màxim en el ' Desenes 2 'el tensor és '3'. Vol dir que el valor més alt del tensor es troba al tercer índex que és ' 15 ”:

Pas 5: cerqueu índexs de valor màxim al llarg de les columnes

A més, els usuaris també poden trobar els índexs/índexs dels valors màxims al llarg de cada columna d'un tensor. Per exemple, podem utilitzar el ' dim=0 ” argument amb la funció “torch.argmax()”. Troba els índexs dels valors màxims al llarg de les columnes del ' Desenes 2 ” tensor i després imprimeix aquests índexs:

col_index = torxa. argmax ( Desenes 2 , dim = 0 )

imprimir ( 'Índexs en columnes:' , col_index )

La sortida següent mostra els índexs dels valors màxims al llarg de cada columna del tensor:

Pas 6: cerqueu índexs de valor màxim al llarg de les files

De la mateixa manera, els usuaris també poden trobar els índexs/índexs dels valors màxims al llarg de cada fila d'un tensor. Per exemple, utilitzeu el ' dim=1 ” amb la funció “torch.argmax()” per trobar els índexs dels valors màxims al llarg de les files del tensor “Tens2” i després imprimir aquests índexs:

indice_row = torxa. argmax ( Desenes 2 , dim = 1 )

imprimir ( 'Índexs en files:' , indice_row )

Els índexs del valor màxim al llarg de cada fila d'un tensor 'Tens2' es poden veure a continuació:

Hem explicat de manera eficient el mètode per utilitzar el mètode 'torch.argmax()' a PyTorch.

Nota : Podeu accedir al nostre quadern Google Colab aquí enllaç .

Conclusió

Per utilitzar el mètode 'torch.argmax()' a PyTorch, primer, importeu el ' torxa ” biblioteca. A continuació, creeu el tensor 1D o 2D desitjat i visualitzeu-ne els elements. A continuació, utilitzeu ' torch.argmax() ” mètode per trobar/calcular els índexs/índexs dels valors màxims del tensor. A més, els usuaris també poden trobar els índexs del valor màxim al llarg de cada fila o columna del tensor mitjançant el ' dim ” argument. Finalment, visualitzeu l'índex del valor màxim al tensor d'entrada. Aquest bloc ha exemplificat el mètode per utilitzar el mètode 'torch.argmax()' a PyTorch.