Recherche de site Web

Une introduction douce à la rétropropagation dans le temps


La rétropropagation à travers le temps, ou BPTT, est l'algorithme d'entraînement utilisé pour mettre à jour les poids dans les réseaux neuronaux récurrents comme les LSTM.

Pour cadrer efficacement les problèmes de prédiction de séquence pour les réseaux de neurones récurrents, vous devez avoir une solide compréhension conceptuelle de ce que fait la rétropropagation dans le temps et de la manière dont des variations configurables telles que la rétropropagation tronquée dans le temps affecteront les compétences, la stabilité et la vitesse lors de la formation de votre réseau. post, vous obtiendrez une introduction douce à la rétropropagation à travers le temps destinée au praticien (pas d'équations !).

Dans cet article, vous obtiendrez une introduction douce à la rétropropagation à travers le temps destinée au praticien (pas d'équations !).

Après avoir lu cet article, vous saurez :

  • Qu'est-ce que la rétropropagation dans le temps et comment elle se rapporte à l'algorithme d'entraînement de rétropropagation utilisé par les réseaux Perceptron multicouches.
  • Les motivations qui conduisent à la nécessité d'une rétropropagation tronquée dans le temps, la variante la plus largement utilisée en apprentissage profond pour la formation des LSTM.
  • Une notation pour réfléchir à la façon de configurer la rétropropagation tronquée dans le temps et aux configurations canoniques utilisées dans la recherche et par les bibliothèques d'apprentissage profond.

Démarrez votre projet avec mon nouveau livre Long Short-Term Memory Networks With Python, comprenant des tutoriels étape par étape et le code source Python fichiers pour tous les exemples.

Commençons.

Algorithme de formation à la rétropropagation

La rétropropagation fait référence à deux choses :

  • Méthode mathématique utilisée pour calculer les dérivées et application de la règle de la chaîne dérivée.
  • L'algorithme de formation pour mettre à jour les pondérations du réseau afin de minimiser les erreurs.

C’est cette dernière compréhension de la rétropropagation que nous utilisons ici.

L'objectif de l'algorithme d'entraînement par rétropropagation est de modifier les poids d'un réseau neuronal afin de minimiser l'erreur des sorties du réseau par rapport à certaines sorties attendues en réponse aux entrées correspondantes.

Il s'agit d'un algorithme d'apprentissage supervisé qui permet de corriger le réseau au regard des erreurs spécifiques commises.

L'algorithme général est le suivant :

  1. Présentez un modèle d'entrée de formation et propagez-le à travers le réseau pour obtenir un résultat.
  2. Comparez les résultats prédits aux résultats attendus et calculez l’erreur.
  3. Calculez les dérivées de l'erreur par rapport aux poids du réseau.
  4. Ajustez les poids pour minimiser l’erreur.
  5. Répéter.

Pour en savoir plus sur la rétropropagation, consultez l'article :

  • Comment implémenter l'algorithme de rétropropagation à partir de zéro en Python

L'algorithme d'entraînement par rétropropagation convient à l'entraînement de réseaux neuronaux à action directe sur des paires entrée-sortie de taille fixe, mais qu'en est-il des données de séquence qui peuvent être ordonnées temporellement ?

Rétropropagation dans le temps

La rétropropagation à travers le temps, ou BPTT, est l'application de l'algorithme d'entraînement de rétropropagation au réseau neuronal récurrent appliqué pour séquencer des données comme une série chronologique.

Un réseau neuronal récurrent affiche une entrée à chaque pas de temps et prédit une sortie.

Conceptuellement, BPTT fonctionne en déroulant tous les pas de temps d'entrée. Chaque pas de temps a un pas de temps d'entrée, une copie du réseau et une sortie. Les erreurs sont ensuite calculées et accumulées pour chaque pas de temps. Le réseau est reconstitué et les poids sont mis à jour.

Spatialement, chaque pas de temps du réseau neuronal récurrent déroulé peut être considéré comme une couche supplémentaire étant donné la dépendance en termes d'ordre du problème et l'état interne du pas de temps précédent est pris comme entrée pour le pas de temps suivant.

Nous pouvons résumer l’algorithme comme suit :

  1. Présentez une séquence de pas de temps de paires d’entrée et de sortie au réseau.
  2. Déroulez le réseau, puis calculez et accumulez les erreurs à chaque pas de temps.
  3. Regroupez le réseau et mettez à jour les poids.
  4. Répéter.

BPTT peut être coûteux en calcul à mesure que le nombre de pas de temps augmente.

Si les séquences d'entrée sont composées de milliers de pas de temps, il s'agira alors du nombre de dérivées requis pour une seule mise à jour du poids de mise à jour. Cela peut provoquer la disparition ou l'explosion des poids (passer à zéro ou déborder) et rendre bruyant l'apprentissage lent et les compétences de modélisation.

Rétropropagation tronquée dans le temps

La rétropropagation tronquée dans le temps, ou TBPTT, est une version modifiée de l'algorithme d'entraînement BPTT pour les réseaux de neurones récurrents où la séquence est traitée un pas de temps à la fois et périodiquement (k1 pas de temps), la mise à jour BPTT est effectuée pour un nombre fixe de pas de temps ( k2 pas de temps).

Ilya Sutskever le dit clairement dans sa thèse :

La rétropropagation tronquée est sans doute la méthode la plus pratique pour former des RNN.

L’un des principaux problèmes du BPTT est le coût élevé d’une mise à jour d’un seul paramètre, ce qui rend impossible l’utilisation d’un grand nombre d’itérations.

Le coût peut être réduit avec une méthode naïve qui divise la séquence de 1 000 longueurs en 50 séquences (disons) chacune de longueur 20 et traite chaque séquence de longueur 20 comme un cas de formation distinct. Il s’agit d’une approche sensée qui peut bien fonctionner dans la pratique, mais elle ne tient pas compte des dépendances temporelles qui s’étendent sur plus de 20 pas de temps.

Le BPTT tronqué est une méthode étroitement liée. Il traite la séquence un pas de temps à la fois, et à chaque pas de temps k1, il exécute BPTT pour les pas de temps k2, donc une mise à jour des paramètres peut être bon marché si k2 est petit. Par conséquent, ses états cachés ont été exposés à de nombreux intervalles de temps et peuvent donc contenir des informations utiles sur un passé lointain, qui pourraient être exploitées de manière opportuniste.

— Ilya Sutskever, Formation aux réseaux de neurones récurrents, thèse, 2013

Nous pouvons résumer l’algorithme comme suit :

  1. Présentez une séquence de k1 pas de temps de paires d’entrée et de sortie au réseau.
  2. Déroulez le réseau, puis calculez et accumulez les erreurs sur des pas de temps k2.
  3. Regroupez le réseau et mettez à jour les poids.
  4. Répéter

L'algorithme TBPTT nécessite la prise en compte de deux paramètres :

  • k1 : nombre de pas de temps de passage en avant entre les mises à jour. Généralement, cela influence la lenteur ou la rapidité de l'entraînement, compte tenu de la fréquence à laquelle les mises à jour de poids sont effectuées.
  • k2 : nombre de pas de temps auxquels appliquer BPTT. Généralement, il doit être suffisamment grand pour capturer la structure temporelle du problème afin que le réseau puisse l'apprendre. Une valeur trop élevée entraîne la disparition des dégradés.

Pour que ce soit plus clair :

… on peut utiliser une approximation historique limitée dans laquelle les informations pertinentes sont enregistrées pendant un nombre fixe h de pas de temps et toute information plus ancienne que cela est oubliée. En général, cela doit être considéré comme une technique heuristique destinée à simplifier le calcul, même si, comme indiqué ci-dessous, elle peut parfois servir d'approximation adéquate du gradient réel et peut également être plus appropriée dans les situations où les poids sont ajustés en fonction du réseau. court. Appelons cet algorithme la rétropropagation tronquée dans le temps. Avec h représentant le nombre de pas de temps antérieurs enregistrés, cet algorithme sera noté BPTT(h).

Notez que dans BPTT(h), un passage en arrière par les h pas de temps les plus récents est effectué à nouveau chaque fois que le réseau passe par un pas de temps supplémentaire. Pour généraliser cela, on peut envisager de laisser le réseau parcourir h0 pas de temps supplémentaires avant d'effectuer le prochain calcul BPTT, où h0 <= h.

La caractéristique clé de cet algorithme est que le prochain passage en arrière n'est effectué qu'au pas de temps t + h0 ; dans l'intervalle, l'historique des entrées réseau, l'état du réseau et les valeurs cibles sont enregistrés dans le tampon historique, mais aucun traitement n'est effectué sur ces données. Notons cet algorithme BPTT(h; h0). Clairement, BPTT(h) est identique à BPTT(h; 1), et BPTT(h; h) est l'algorithme BPTT par époque.

— Ronald J. Williams et Jing Peng, Un algorithme efficace basé sur un gradient pour la formation en ligne des trajectoires de réseau récurrentes, 1990

Nous pouvons emprunter la notation à Williams et Peng et faire référence aux différentes configurations TBPTT par TBPTT(k1, k2).

En utilisant cette notation, nous pouvons définir quelques approches standards ou courantes :

Notez qu'ici n fait référence au nombre total de pas de temps dans la séquence :

  • TBPTT(n,n) : les mises à jour sont effectuées à la fin de la séquence sur tous les pas de temps de la séquence (par exemple, BPTT classique).
  • TBPTT(1,n) : les pas de temps sont traités un par un, suivis d'une mise à jour qui couvre tous les pas de temps vus jusqu'à présent (par exemple, le TBPTT classique de Williams et Peng).
  • TBPTT(k1,1) : le réseau ne dispose probablement pas de suffisamment de contexte temporel pour apprendre, et s'appuie fortement sur l'état et les entrées internes.
  • TBPTT(k1,k2), où k1 : plusieurs mises à jour sont effectuées par séquence, ce qui peut accélérer l'entraînement.
  • TBPTT(k1,k2), où k1=k2 : configuration courante dans laquelle un nombre fixe de pas de temps est utilisé pour les pas de temps avant et arrière (par exemple, 10 s à 100 s).

Nous pouvons voir que toutes les configurations sont une variation de TBPTT(n,n) qui tentent essentiellement de se rapprocher du même effet avec peut-être un entraînement plus rapide et des résultats plus stables.

Le TBPTT canonique rapporté dans les articles peut être considéré comme TBPTT(k1,k2), où k1=k2=h et h<=n, et où le paramètre choisi est petit (des dizaines à des centaines de pas de temps).

Dans des bibliothèques comme TensorFlow et Keras, les choses se ressemblent et h définit la longueur fixe vectorisée des pas de temps des données préparées.

Lectures complémentaires

Cette section fournit quelques ressources pour une lecture plus approfondie.

Livres

  • Neural Smithing : apprentissage supervisé dans les réseaux de neurones artificiels à action directe, 1999
  • Apprentissage profond, 2016

Papiers

  • Apprentissage des représentations par rétro-propagation d'erreurs, 1986
  • Rétropropagation dans le temps : ce qu'elle fait et comment la faire, 1990
  • Un algorithme efficace basé sur le gradient pour la formation en ligne des trajectoires de réseau récurrentes, 1990
  • Formation aux réseaux de neurones récurrents, Thèse, 2013
  • Algorithmes d'apprentissage basés sur le gradient pour les réseaux récurrents et leur complexité informatique, 1995

Articles

  • Rétropropagation sur Wikipédia
  • Rétropropagation dans le temps sur Wikipédia
  • Styles de rétropropagation tronquée
  • Réponse à la question « RNN : quand appliquer le BPTT et/ou mettre à jour les pondérations ? » sur CrossValidated

Résumé

Dans cet article, vous avez découvert la rétropropagation dans le temps pour entraîner les réseaux de neurones récurrents.

Concrètement, vous avez appris :

  • Qu'est-ce que la rétropropagation dans le temps et comment elle se rapporte à l'algorithme d'entraînement de rétropropagation utilisé par les réseaux Perceptron multicouches.
  • Les motivations qui conduisent à la nécessité d'une rétropropagation tronquée dans le temps, la variante la plus largement utilisée en apprentissage profond pour la formation des LSTM.
  • Une notation pour réfléchir à la façon de configurer la rétropropagation tronquée dans le temps et aux configurations canoniques utilisées dans la recherche et par les bibliothèques d'apprentissage profond.

Avez-vous des questions sur la rétropropagation dans le temps ?
Posez vos questions dans les commentaires ci-dessous et je ferai de mon mieux pour y répondre.

Articles connexes