Recherche de site Web

Comment implémenter l'algorithme d'arbre de décision à partir de zéro en Python


Les arbres de décision sont une méthode de prédiction puissante et extrêmement populaire.

Ils sont populaires parce que le modèle final est si facile à comprendre aussi bien par les praticiens que par les experts du domaine. L'arbre de décision final peut expliquer exactement pourquoi une prédiction spécifique a été faite, ce qui le rend très attractif pour une utilisation opérationnelle.

Les arbres de décision constituent également la base de méthodes d'ensemble plus avancées telles que le bagging, les forêts aléatoires et l'augmentation du gradient.

Dans ce didacticiel, vous découvrirez comment implémenter l'algorithme d'arbre de classification et de régression à partir de zéro avec Python.

Après avoir terminé ce tutoriel, vous saurez :

  • Comment calculer et évaluer les points de partage candidats dans des données.
  • Comment organiser les divisions dans une structure d'arbre de décision.
  • Comment appliquer l'algorithme de l'arbre de classification et de régression à un problème réel.

Démarrez votre projet avec mon nouveau livre Machine Learning Algorithms From Scratch, comprenant des tutoriels pas à pas et les fichiers code source Python pour tous les exemples.

Commençons.

  • Mise à jour janvier 2017 : modification du calcul de Fold_size dans cross_validation_split() pour qu'il soit toujours un nombre entier. Résout les problèmes avec Python 3.
  • Mise à jour de février 2017 : correction d'un bug dans build_tree.
  • Mise à jour d'août 2017 : Correction d'un bug dans le calcul de Gini, ajout de la pondération manquante des scores de Gini des groupes par taille de groupe (merci Michael !).
  • Mise à jour d'août 2018 : testée et mise à jour pour fonctionner avec Python 3.6.

Descriptions

Cette section fournit une brève introduction à l'algorithme de l'arbre de classification et de régression et à l'ensemble de données Banknote utilisé dans ce didacticiel.

Arbres de classification et de régression

Les arbres de classification et de régression ou CART en abrégé sont un acronyme introduit par Leo Breiman pour désigner les algorithmes d'arbre de décision qui peuvent être utilisés pour des problèmes de modélisation prédictive de classification ou de régression.

Nous nous concentrerons sur l'utilisation de CART pour la classification dans ce didacticiel.

La représentation du modèle CART est un arbre binaire. Il s'agit du même arbre binaire issu d'algorithmes et de structures de données, rien de trop sophistiqué (chaque nœud peut avoir zéro, un ou deux nœuds enfants).

Un nœud représente une variable d'entrée unique (X) et un point de partage sur cette variable, en supposant que la variable est numérique. Les nœuds feuilles (également appelés nœuds terminaux) de l'arbre contiennent une variable de sortie (y) qui est utilisée pour faire une prédiction.

Une fois créé, un arbre peut être parcouru avec une nouvelle ligne de données suivant chaque branche avec les divisions jusqu'à ce qu'une prédiction finale soit faite.

La création d'un arbre de décision binaire est en fait un processus de division de l'espace d'entrée. Une approche gourmande est utilisée pour diviser l'espace appelée division binaire récursive. Il s'agit d'une procédure numérique dans laquelle toutes les valeurs sont alignées et différents points de partage sont essayés et testés à l'aide d'une fonction de coût.

La répartition présentant le meilleur coût (le coût le plus bas car nous minimisons le coût) est sélectionnée. Toutes les variables d'entrée et tous les points de partage possibles sont évalués et choisis de manière gourmande en fonction de la fonction de coût.

  • Régression : la fonction de coût minimisée pour choisir les points de partage est la somme des erreurs quadratiques sur tous les échantillons d'apprentissage qui se trouvent dans le rectangle.
  • Classification : la fonction de coût Gini est utilisée pour fournir une indication du degré de pureté des nœuds, où la pureté des nœuds fait référence à la manière dont les données d'entraînement attribuées à chaque nœud sont mélangées.

Le fractionnement se poursuit jusqu'à ce que les nœuds contiennent un nombre minimum d'exemples de formation ou qu'une profondeur d'arborescence maximale soit atteinte.

Ensemble de données sur les billets de banque

L'ensemble de données sur les billets de banque consiste à prédire si un billet de banque donné est authentique à partir d'un certain nombre de mesures prises à partir d'une photographie.

L'ensemble de données contient 1 372 lignes avec 5 variables numériques. Il s'agit d'un problème de classification à deux classes (classification binaire).

Vous trouverez ci-dessous une liste des cinq variables de l'ensemble de données.

  1. variance de l'image transformée en ondelettes (continue).
  2. asymétrie de l'image transformée en ondelettes (continue).
  3. aplatissement de l'image transformée en ondelettes (continue).
  4. entropie de l'image (continue).
  5. classe (entier).

Vous trouverez ci-dessous un échantillon des 5 premières lignes de l'ensemble de données

3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,0
0.32924,-4.4552,4.5718,-0.9888,0
4.3684,9.6718,-3.9606,-3.1625,0

En utilisant l'algorithme de la règle zéro pour prédire la valeur de classe la plus courante, la précision de base du problème est d'environ 50 %.

Vous pouvez en savoir plus et télécharger l’ensemble de données à partir du référentiel UCI Machine Learning.

Téléchargez l'ensemble de données et placez-le dans votre répertoire de travail actuel sous le nom de fichier data_banknote_authentication.csv.

Tutoriel

Ce tutoriel se décompose en 5 parties :

  1. Indice de Gini.
  2. Créez un fractionnement.
  3. Construisez un arbre.
  4. Faites une prédiction.
  5. Étude de cas sur les billets de banque.

Ces étapes vous donneront les bases dont vous avez besoin pour implémenter l'algorithme CART à partir de zéro et l'appliquer à vos propres problèmes de modélisation prédictive.

1. Indice de Gini

L'indice de Gini est le nom de la fonction de coût utilisée pour évaluer les fractionnements dans l'ensemble de données.

Une division dans l'ensemble de données implique un attribut d'entrée et une valeur pour cet attribut. Il peut être utilisé pour diviser les modèles d’entraînement en deux groupes de lignes.

Un score de Gini donne une idée de la qualité d'une répartition en fonction de la mixité des classes dans les deux groupes créés par la répartition. Une séparation parfaite donne un score de Gini de 0, tandis que le pire des cas, qui donne lieu à des classes 50/50 dans chaque groupe, donne un score de Gini de 0,5 (pour un problème à 2 classes).

Le calcul de Gini est mieux démontré avec un exemple.

Nous avons deux groupes de données avec 2 lignes dans chaque groupe. Les lignes du premier groupe appartiennent toutes à la classe 0 et les lignes du deuxième groupe appartiennent à la classe 1, c'est donc une répartition parfaite.

Nous devons d’abord calculer la proportion de classes dans chaque groupe.

proportion = count(class_value) / count(rows)

Les proportions pour cet exemple seraient :

group_1_class_0 = 2 / 2 = 1
group_1_class_1 = 0 / 2 = 0
group_2_class_0 = 0 / 2 = 0
group_2_class_1 = 2 / 2 = 1

Gini est ensuite calculé pour chaque nœud enfant comme suit :

gini_index = sum(proportion * (1.0 - proportion))
gini_index = 1.0 - sum(proportion * proportion)

L'indice de Gini pour chaque groupe doit ensuite être pondéré en fonction de la taille du groupe, par rapport à tous les échantillons du parent, par ex. tous les échantillons actuellement regroupés. On peut ajouter cette pondération au calcul de Gini pour un groupe comme suit :

gini_index = (1.0 - sum(proportion * proportion)) * (group_size/total_samples)

Dans cet exemple, les scores de Gini pour chaque groupe sont calculés comme suit :

Gini(group_1) = (1 - (1*1 + 0*0)) * 2/4
Gini(group_1) = 0.0 * 0.5 
Gini(group_1) = 0.0 
Gini(group_2) = (1 - (0*0 + 1*1)) * 2/4
Gini(group_2) = 0.0 * 0.5 
Gini(group_2) = 0.0

Les scores sont ensuite ajoutés sur chaque nœud enfant au point de partage pour donner un score Gini final pour le point de partage qui peut être comparé à d'autres points de partage candidats.

Le Gini pour ce point de partage serait alors calculé comme 0,0 + 0,0 ou un score de Gini parfait de 0,0.

Vous trouverez ci-dessous une fonction nommée gini_index() qui calcule l'indice de Gini pour une liste de groupes et une liste de valeurs de classe connues.

Vous pouvez voir qu'il y a quelques contrôles de sécurité pour éviter une division par zéro pour un groupe vide.

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

Nous pouvons tester cette fonction avec notre exemple concret ci-dessus. Nous pouvons également le tester pour le pire des cas d’une répartition 50/50 dans chaque groupe. L’exemple complet est répertorié ci-dessous.

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# test Gini values
print(gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]))
print(gini_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1]))

L'exécution de l'exemple imprime les deux scores de Gini, d'abord le score du pire des cas à 0,5, suivi du score du meilleur des cas à 0,0.

0.5
0.0

Maintenant que nous savons comment évaluer les résultats d’une scission, examinons la création de divisions.

2. Créer une division

Une division est composée d'un attribut dans l'ensemble de données et d'une valeur.

Nous pouvons résumer cela comme l'index d'un attribut à diviser et la valeur par laquelle diviser les lignes sur cet attribut. Il s'agit simplement d'un raccourci utile pour l'indexation en lignes de données.

La création d'une division implique trois parties, la première que nous avons déjà examinée est le calcul du score de Gini. Les deux parties restantes sont :

  1. Fractionner un ensemble de données.
  2. Évaluation de tous les fractionnements.

Jetons un coup d'œil à chacun.

2.1. Fractionner un ensemble de données

Diviser un ensemble de données signifie séparer un ensemble de données en deux listes de lignes en fonction de l'index d'un attribut et d'une valeur divisée pour cet attribut.

Une fois que nous avons les deux groupes, nous pouvons alors utiliser notre score Gini ci-dessus pour évaluer le coût de la répartition.

Le fractionnement d'un ensemble de données implique de parcourir chaque ligne, de vérifier si la valeur de l'attribut est inférieure ou supérieure à la valeur divisée et de l'attribuer respectivement au groupe de gauche ou de droite.

Vous trouverez ci-dessous une fonction nommée test_split() qui implémente cette procédure.

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

Pas grand-chose.

Notez que le groupe de droite contient toutes les lignes dont la valeur à l'index est supérieure ou égale à la valeur de fractionnement.

2.2. Évaluation de toutes les divisions

Avec la fonction Gini ci-dessus et la fonction test split, nous avons désormais tout ce dont nous avons besoin pour évaluer les splits.

Étant donné un ensemble de données, nous devons vérifier chaque valeur de chaque attribut en tant que répartition candidate, évaluer le coût de la répartition et trouver la meilleure répartition possible.

Une fois la meilleure répartition trouvée, nous pouvons l’utiliser comme nœud dans notre arbre de décision.

Il s’agit d’un algorithme exhaustif et gourmand.

Nous utiliserons un dictionnaire pour représenter un nœud dans l'arbre de décision car nous pouvons stocker les données par nom. Lors de la sélection de la meilleure répartition et de son utilisation comme nouveau nœud pour l'arborescence, nous stockerons l'index de l'attribut choisi, la valeur de cet attribut par lequel diviser et les deux groupes de données divisés par le point de division choisi.

Chaque groupe de données est son propre petit ensemble de données composé uniquement des lignes attribuées au groupe de gauche ou de droite par le processus de fractionnement. Vous pouvez imaginer comment nous pourrions à nouveau diviser chaque groupe, de manière récursive, à mesure que nous construisons notre arbre de décision.

Vous trouverez ci-dessous une fonction nommée get_split() qui implémente cette procédure. Vous pouvez voir qu'il parcourt chaque attribut (à l'exception de la valeur de classe), puis chaque valeur de cet attribut, en divisant et en évaluant les divisions au fur et à mesure.

La meilleure répartition est enregistrée puis renvoyée une fois toutes les vérifications terminées.

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

Nous pouvons créer un petit ensemble de données pour tester cette fonction et l'ensemble de notre processus de fractionnement d'ensemble de données.

X1			X2			Y
2.771244718		1.784783929		0
1.728571309		1.169761413		0
3.678319846		2.81281357		0
3.961043357		2.61995032		0
2.999208922		2.209014212		0
7.497545867		3.162953546		1
9.00220326		3.339047188		1
7.444542326		0.476683375		1
10.12493903		3.234550982		1
6.642287351		3.319983761		1

Nous pouvons tracer cet ensemble de données en utilisant des couleurs distinctes pour chaque classe. Vous pouvez voir qu'il ne serait pas difficile de choisir manuellement une valeur de X1 (axe des x sur le tracé) pour diviser cet ensemble de données.

L’exemple ci-dessous rassemble tout cela.

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			print('X%d < %.3f Gini=%.3f' % ((index+1), row[index], gini))
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]
split = get_split(dataset)
print('Split: [X%d < %.3f]' % ((split['index']+1), split['value']))

La fonction get_split() a été modifiée pour imprimer chaque point de partage et son index Gini tel qu'il a été évalué.

L'exécution de l'exemple imprime tous les scores de Gini, puis imprime le score de la meilleure répartition dans l'ensemble de données de X1 < 6,642 avec un indice de Gini de 0,0 ou une répartition parfaite.

X1 < 2.771 Gini=0.444
X1 < 1.729 Gini=0.500
X1 < 3.678 Gini=0.286
X1 < 3.961 Gini=0.167
X1 < 2.999 Gini=0.375
X1 < 7.498 Gini=0.286
X1 < 9.002 Gini=0.375
X1 < 7.445 Gini=0.167
X1 < 10.125 Gini=0.444
X1 < 6.642 Gini=0.000
X2 < 1.785 Gini=0.500
X2 < 1.170 Gini=0.444
X2 < 2.813 Gini=0.320
X2 < 2.620 Gini=0.417
X2 < 2.209 Gini=0.476
X2 < 3.163 Gini=0.167
X2 < 3.339 Gini=0.444
X2 < 0.477 Gini=0.500
X2 < 3.235 Gini=0.286
X2 < 3.320 Gini=0.375
Split: [X1 < 6.642]

Maintenant que nous savons comment trouver les meilleurs points de partage dans un ensemble de données ou une liste de lignes, voyons comment nous pouvons l'utiliser pour créer un arbre de décision.

3. Construisez un arbre

Créer le nœud racine de l’arborescence est simple.

Nous appelons la fonction get_split() ci-dessus en utilisant l'ensemble de données.

Ajouter plus de nœuds à notre arbre est plus intéressant.

La construction d'un arbre peut être divisée en 3 parties principales :

  1. Nœuds terminaux.
  2. Fractionnement récursif.
  3. Construire un arbre.

3.1. Nœuds terminaux

Nous devons décider quand arrêter de faire pousser un arbre.

Nous pouvons le faire en utilisant la profondeur et le nombre de lignes dont le nœud est responsable dans l'ensemble de données de formation.

  • Profondeur maximale de l'arbre. Il s'agit du nombre maximum de nœuds à partir du nœud racine de l'arborescence. Une fois qu'une profondeur maximale de l'arbre est atteinte, nous devons arrêter de diviser en ajoutant de nouveaux nœuds. Les arbres plus profonds sont plus complexes et sont plus susceptibles de surajuster les données d'entraînement.
  • Enregistrements de nœuds minimum. Il s'agit du nombre minimum de modèles de formation dont un nœud donné est responsable. Une fois atteint ou inférieur à ce minimum, nous devons arrêter de diviser et d'ajouter de nouveaux nœuds. Les nœuds qui représentent trop peu de modèles de formation sont censés être trop spécifiques et susceptibles de surajuster les données de formation.

Ces deux approches seront des arguments spécifiés par l'utilisateur pour notre procédure de construction d'arborescence.

Il y a encore une condition. Il est possible de choisir une répartition dans laquelle toutes les lignes appartiennent à un seul groupe. Dans ce cas, nous ne pourrons pas continuer à diviser et à ajouter des nœuds enfants car nous n'aurons aucun enregistrement à diviser d'un côté ou de l'autre.

Nous avons maintenant quelques idées sur le moment où arrêter la croissance de l’arbre. Lorsque nous arrêtons de croître à un point donné, ce nœud est appelé nœud terminal et est utilisé pour faire une prédiction finale.

Cela se fait en prenant le groupe de lignes attribué à ce nœud et en sélectionnant la valeur de classe la plus courante dans le groupe. Cela servira à faire des prédictions.

Vous trouverez ci-dessous une fonction nommée to_terminal() qui sélectionnera une valeur de classe pour un groupe de lignes. Il renvoie la valeur de sortie la plus courante dans une liste de lignes.

# Create a terminal node value
def to_terminal(group):
	outcomes = [row[-1] for row in group]
	return max(set(outcomes), key=outcomes.count)

3.2. Fractionnement récursif

Nous savons comment et quand créer des nœuds terminaux, nous pouvons maintenant construire notre arbre.

Construire un arbre de décision implique d'appeler encore et encore la fonction get_split() développée ci-dessus sur les groupes créés pour chaque nœud.

Les nouveaux nœuds ajoutés à un nœud existant sont appelés nœuds enfants. Un nœud peut avoir zéro enfant (un nœud terminal), un enfant (un côté fait directement une prédiction) ou deux nœuds enfants. Nous désignerons les nœuds enfants comme gauche et droite dans la représentation dictionnaire d'un nœud donné.

Une fois qu'un nœud est créé, nous pouvons créer des nœuds enfants de manière récursive sur chaque groupe de données de la division en appelant à nouveau la même fonction.

Vous trouverez ci-dessous une fonction qui implémente cette procédure récursive. Il prend un nœud comme argument ainsi que la profondeur maximale, le nombre minimum de motifs dans un nœud et la profondeur actuelle d'un nœud.

Vous pouvez imaginer comment cela pourrait être appelé pour la première fois passer le nœud racine et la profondeur de 1. Cette fonction est mieux expliquée en étapes :

  1. Premièrement, les deux groupes de données divisés par le nœud sont extraits pour être utilisés et supprimés du nœud. Au fur et à mesure que nous travaillons sur ces groupes, le nœud n'a plus besoin d'accéder à ces données.
  2. Ensuite, nous vérifions si le groupe de lignes gauche ou droit est vide et si c'est le cas, nous créons un nœud terminal en utilisant les enregistrements dont nous disposons.
  3. On vérifie ensuite si on a atteint notre profondeur maximale et si c'est le cas on crée un nœud terminal.
  4. Nous traitons ensuite l'enfant de gauche, en créant un nœud terminal si le groupe de lignes est trop petit, sinon nous créons et ajoutons le nœud de gauche d'abord en profondeur jusqu'à ce que le bas de l'arbre soit atteint sur cette branche.
  5. Le côté droit est ensuite traité de la même manière, à mesure que nous remontons l'arbre construit jusqu'à la racine.
# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for max depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)

3.3. Construire un arbre

Nous pouvons maintenant rassembler toutes les pièces.

Construire l'arborescence implique de créer le nœud racine et d'appeler la fonction split() qui s'appelle ensuite de manière récursive pour construire l'arborescence entière.

Ci-dessous se trouve la petite fonction build_tree() qui implémente cette procédure.

# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

Nous pouvons tester toute cette procédure en utilisant le petit ensemble de données que nous avons créé ci-dessus.

Vous trouverez ci-dessous l'exemple complet.

Une petite fonction print_tree() est également incluse qui imprime de manière récursive les nœuds de l'arbre de décision avec une ligne par nœud. Bien qu’il ne soit pas aussi frappant qu’un véritable arbre de décision, il donne une idée de la structure arborescente et des décisions prises tout au long.

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value
def to_terminal(group):
	outcomes = [row[-1] for row in group]
	return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for max depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

# Print a decision tree
def print_tree(node, depth=0):
	if isinstance(node, dict):
		print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
		print_tree(node['left'], depth+1)
		print_tree(node['right'], depth+1)
	else:
		print('%s[%s]' % ((depth*' ', node)))

dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]
tree = build_tree(dataset, 1, 1)
print_tree(tree)

Nous pouvons faire varier l'argument de profondeur maximale au fur et à mesure que nous exécutons cet exemple et voyons l'effet sur l'arborescence imprimée.

Avec une profondeur maximale de 1 (le deuxième paramètre de l'appel à la fonction build_tree()), nous pouvons voir que l'arbre utilise le fractionnement parfait que nous avons découvert dans la section précédente. Il s’agit d’un arbre à un nœud, également appelé souche de décision.

[X1 < 6.642]
 [0]
 [1]

En augmentant la profondeur maximale à 2, nous forçons l'arbre à faire des divisions même lorsqu'aucune n'est requise. L'attribut X1 est ensuite utilisé à nouveau par les enfants gauche et droit du nœud racine pour diviser le mélange déjà parfait de classes.

[X1 < 6.642]
 [X1 < 2.771]
  [0]
  [0]
 [X1 < 7.498]
  [1]
  [1]

Enfin, et de manière perverse, on peut forcer un niveau supplémentaire de splits avec une profondeur maximale de 3.

[X1 < 6.642]
 [X1 < 2.771]
  [0]
  [X1 < 2.771]
   [0]
   [0]
 [X1 < 7.498]
  [X1 < 7.445]
   [1]
   [1]
  [X1 < 7.498]
   [1]
   [1]

Ces tests montrent qu'il existe une grande opportunité d'affiner la mise en œuvre pour éviter des scissions inutiles. Ceci est laissé comme une extension.

Maintenant que nous pouvons créer un arbre de décision, voyons comment nous pouvons l’utiliser pour faire des prédictions sur de nouvelles données.

4. Faites une prédiction

Faire des prédictions avec un arbre de décision implique de naviguer dans l'arbre avec la ligne de données spécifiquement fournie.

Encore une fois, nous pouvons implémenter cela en utilisant une fonction récursive, où la même routine de prédiction est appelée à nouveau avec les nœuds enfants gauche ou droit, en fonction de la manière dont la division affecte les données fournies.

Il faut vérifier si un nœud enfant est soit une valeur terminale à renvoyer comme prédiction, soit s'il s'agit d'un nœud dictionnaire contenant un autre niveau de l'arbre à considérer.

Vous trouverez ci-dessous la fonction predict() qui implémente cette procédure. Vous pouvez voir comment l'index et la valeur dans un nœud donné

Vous pouvez voir comment l'index et la valeur d'un nœud donné sont utilisés pour évaluer si la ligne de données fournies se situe à gauche ou à droite de la division.

# Make a prediction with a decision tree
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

Nous pouvons utiliser notre ensemble de données artificiel pour tester cette fonction. Vous trouverez ci-dessous un exemple qui utilise un arbre de décision codé en dur avec un seul nœud qui divise au mieux les données (une souche de décision).

L'exemple effectue une prédiction pour chaque ligne de l'ensemble de données.

# Make a prediction with a decision tree
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]

#  predict with a stump
stump = {'index': 0, 'right': 1, 'value': 6.642287351, 'left': 0}
for row in dataset:
	prediction = predict(stump, row)
	print('Expected=%d, Got=%d' % (row[-1], prediction))

L’exécution de l’exemple imprime la prédiction correcte pour chaque ligne, comme prévu.

Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1

Nous savons désormais comment créer un arbre de décision et l'utiliser pour faire des prédictions. Maintenant, appliquons-le à un ensemble de données réel.

5. Étude de cas sur les billets de banque

Cette section applique l'algorithme CART à l'ensemble de données Bank Note.

La première étape consiste à charger l'ensemble de données et à convertir les données chargées en nombres que nous pouvons utiliser pour calculer les points de partage. Pour cela, nous utiliserons la fonction d'assistance load_csv() pour charger le fichier et str_column_to_float() pour convertir les numéros de chaîne en flottants.

Nous évaluerons l'algorithme en utilisant une validation croisée k fois avec 5 fois. Cela signifie que 1 372/5=274,4, soit un peu plus de 270 enregistrements, seront utilisés dans chaque pli. Nous utiliserons les fonctions d'assistance evaluate_algorithm() pour évaluer l'algorithme avec validation croisée et accuracy_metric() pour calculer l'exactitude des prédictions.

Une nouvelle fonction nommée decision_tree() a été développée pour gérer l'application de l'algorithme CART, en créant d'abord l'arbre à partir de l'ensemble de données d'entraînement, puis en utilisant l'arbre pour faire des prédictions sur un ensemble de données de test.

L’exemple complet est répertorié ci-dessous.

# CART on the Bank Note dataset
from random import seed
from random import randrange
from csv import reader

# Load a CSV file
def load_csv(filename):
	file = open(filename, "rt")
	lines = reader(file)
	dataset = list(lines)
	return dataset

# Convert string column to float
def str_column_to_float(dataset, column):
	for row in dataset:
		row[column] = float(row[column].strip())

# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):
	dataset_split = list()
	dataset_copy = list(dataset)
	fold_size = int(len(dataset) / n_folds)
	for i in range(n_folds):
		fold = list()
		while len(fold) < fold_size:
			index = randrange(len(dataset_copy))
			fold.append(dataset_copy.pop(index))
		dataset_split.append(fold)
	return dataset_split

# Calculate accuracy percentage
def accuracy_metric(actual, predicted):
	correct = 0
	for i in range(len(actual)):
		if actual[i] == predicted[i]:
			correct += 1
	return correct / float(len(actual)) * 100.0

# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):
	folds = cross_validation_split(dataset, n_folds)
	scores = list()
	for fold in folds:
		train_set = list(folds)
		train_set.remove(fold)
		train_set = sum(train_set, [])
		test_set = list()
		for row in fold:
			row_copy = list(row)
			test_set.append(row_copy)
			row_copy[-1] = None
		predicted = algorithm(train_set, test_set, *args)
		actual = [row[-1] for row in fold]
		accuracy = accuracy_metric(actual, predicted)
		scores.append(accuracy)
	return scores

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
	# count all samples at split point
	n_instances = float(sum([len(group) for group in groups]))
	# sum weighted Gini index for each group
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero
		if size == 0:
			continue
		score = 0.0
		# score the group based on the score for each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini

# Select the best split point for a dataset
def get_split(dataset):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	for index in range(len(dataset[0])-1):
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value
def to_terminal(group):
	outcomes = [row[-1] for row in group]
	return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for max depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left)
		split(node['left'], max_depth, min_size, depth+1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right)
		split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size):
	root = get_split(train)
	split(root, max_depth, min_size, 1)
	return root

# Make a prediction with a decision tree
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		else:
			return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		else:
			return node['right']

# Classification and Regression Tree Algorithm
def decision_tree(train, test, max_depth, min_size):
	tree = build_tree(train, max_depth, min_size)
	predictions = list()
	for row in test:
		prediction = predict(tree, row)
		predictions.append(prediction)
	return(predictions)

# Test CART on Bank Note dataset
seed(1)
# load and prepare data
filename = 'data_banknote_authentication.csv'
dataset = load_csv(filename)
# convert string attributes to integers
for i in range(len(dataset[0])):
	str_column_to_float(dataset, i)
# evaluate algorithm
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))

L'exemple utilise la profondeur maximale de l'arbre de 5 couches et le nombre minimum de lignes par nœud de 10. Ces paramètres de CART ont été choisis avec un peu d'expérimentation, mais ne sont en aucun cas optimaux.

L’exécution de l’exemple imprime la précision moyenne de la classification sur chaque pli ainsi que les performances moyennes sur tous les plis.

Vous pouvez voir que CART et la configuration choisie ont atteint une précision de classification moyenne d'environ 97 %, ce qui est nettement meilleur que l'algorithme Zero Rule qui a atteint une précision de 50 %.

Scores: [96.35036496350365, 97.08029197080292, 97.44525547445255, 98.17518248175182, 97.44525547445255]
Mean Accuracy: 97.299%

Rallonges

Cette section répertorie les extensions de ce didacticiel que vous souhaiterez peut-être explorer.

  • Réglage de l'algorithme. L'application de CART à l'ensemble de données sur les billets de banque n'a pas été optimisée. Expérimentez avec différentes valeurs de paramètres et voyez si vous pouvez obtenir de meilleures performances.
  • Entropie croisée. Une autre fonction de coût pour évaluer les divisions est l’entropie croisée (logloss). Vous pouvez implémenter et expérimenter cette fonction de coût alternative.
  • Élagage des arbres. Une technique importante pour réduire le surajustement de l'ensemble de données d'entraînement consiste à élaguer les arbres. Étudier et mettre en œuvre des méthodes d’élagage des arbres.
  • Ensemble de données catégorielles. L'exemple a été conçu pour les données d'entrée avec des attributs d'entrée numériques ou ordinaux, pour expérimenter des données d'entrée catégorielles et des divisions qui peuvent utiliser l'égalité au lieu du classement.
  • Régression. Adaptez l'arbre pour la régression en utilisant une fonction de coût et une méthode différentes pour créer des nœuds terminaux.
  • Plus d'ensembles de données. Appliquez l'algorithme à davantage d'ensembles de données sur le référentiel UCI Machine Learning.

Avez-vous exploré l'une de ces extensions ?
Partagez vos expériences dans les commentaires ci-dessous.

Revoir

Dans ce didacticiel, vous avez découvert comment implémenter l'algorithme d'arbre de décision à partir de zéro avec Python.

Concrètement, vous avez appris :

  • Comment sélectionner et évaluer des points de partage dans un ensemble de données d'entraînement.
  • Comment construire de manière récursive un arbre de décision à partir de plusieurs divisions.
  • Comment appliquer l'algorithme CART à un problème de modélisation prédictive de classification du monde réel.

Avez-vous des questions ?
Posez vos questions dans les commentaires ci-dessous et je ferai de mon mieux pour y répondre.

Articles connexes