Un arbre de décision avec R

decision tree avec RAprès mon article sur le fonctionnement de l’algorithme Decision Tree (disponible ici), voici une manière de le mettre en oeuvre sous R.

 

Nous allons implémenter un arbre CART. Dans mon précédent article j’utilisais l’entropie comme critère de sélection, ici avec l’arbre CART, c’est l’indice de Gini qui est utilisé.

 

Chargement des librairies

On commence par charger 2 librairies qui nous permettront de créer l’arbre de décision et de le représenter

library(rpart)# Pour l’arbre de décision
library(rpart.plot) # Pour la représentation de l’arbre de décision

Nous allons utiliser le dataset ptitanic qui est disponible avec la librairie rpart. Le fichier contient 1309 individus et 6 variables dont survived qui indique si l’individu a survécu ou non au Titanic.

Chargement et préparation des données

Chargement des données
data(ptitanic)

#Description des données
summary(ptitanic)

 

Voilà le résultat, on a la fréquence des modalités pour les variables qualitatives et quelques chiffres clés pour les variables quantitatives

##  pclass        survived       sex           age              sibsp       
##  1st:323   died    :809   female:466   Min.   : 0.1667   Min.   :0.0000  
##  2nd:277   survived:500   male  :843   1st Qu.:21.0000   1st Qu.:0.0000  
##  3rd:709                               Median :28.0000   Median :0.0000  
##                                        Mean   :29.8811   Mean   :0.4989  
##                                        3rd Qu.:39.0000   3rd Qu.:1.0000  
##                                        Max.   :80.0000   Max.   :8.0000  
##                                        NA's   :263                       
##      parch      
##  Min.   :0.000  
##  1st Qu.:0.000  
##  Median :0.000  
##  Mean   :0.385  
##  3rd Qu.:0.000  
##  Max.   :9.000  
## 

Comme pour tout modèle, nous avons besoin de construire l’arbre de décision sur un dataset d’apprentissage et de la tester ensuite sur un dataset de test. La librairie rpart inclus de la validation croisée mais il est toujours préférable de calculer la performance sur un échantillon qui n’est pas impliqué dans le calcul. On sépare donc nos données en 2 échantillons.

#Création d’un dataset d’apprentissage et d’un dataset de validation
nb_lignes <- floor((nrow(ptitanic)*0.75)) #Nombre de lignes de l’échantillon d’apprentissage : 75% du dataset
ptitanic <- ptitanic[sample(nrow(ptitanic)), ] #Ajout de numéros de lignes
ptitanic.train <- ptitanic[1:nb_lignes, ] #Echantillon d’apprentissage
ptitanic.test <- ptitanic[(nb_lignes+1):nrow(ptitanic), ] #Echantillon de test

Apprentissage

Nous allons maintenant construire l’arbre et l’élaguer

#Construction de l’arbre
ptitanic.Tree <- rpart(survived~.,data=ptitanic.train,method=« class », control=rpart.control(minsplit=5,cp=0))

#Affichage du résultat
plot(ptitanic.Tree, uniform=TRUE, branch=0.5, margin=0.1)
text(ptitanic.Tree, all=FALSE, use.n=TRUE)

 

Je ne vous montre pas l’arbre intermédiaire mais on voit qu’il est très développé, d’où l’intérêt de l’élagage pour le simplifier et pour éviter le surapprentissage.

#On cherche à minimiser l’erreur pour définir le niveau d’élagage
plotcp(ptitanic.Tree)

Rplot01
Choix du cp optimal pour un arbre de décision

Le graphique ci-dessus affiche le taux de mauvais classement en fonction de la taille de l’arbre. On cherche à minimiser l’erreur.

#Affichage du cp optimal
print(ptitanic.Tree$cptable[which.min(ptitanic.Tree$cptable[,4]),1])

## [1] 0.0028

#Elagage de l’arbre avec le cp optimal
ptitanic.Tree_Opt <- prune(ptitanic.Tree,cp=ptitanic.Tree$cptable[which.min(ptitanic.Tree$cptable[,4]),1])

#Représentation graphique de l’arbre optimal
prp(ptitanic.Tree_Opt,extra=1)

Voici notre arbre. On a utilisé la représentation graphique avec la fonction prp. On aurait aussi pu utiliser la fonction plot comme nous l’avions fait au début.

Rplot

Ici, pour chaque feuille, R indique la classe prédite, le nombre d’individus de la classe prédite à gauche et le nombre d’individus de l’autre (ou des autres) classes à droite. Par exemple, si on prend la première feuille à gauche qui correspond aux hommes de plus de 10 ans, la feuille contient 660 passagers décédés et 136 survivants. Et cette sous-population est classée dans « Died ».

Validation

On teste et on valide les résultats avec l’échantillon de test

#Prédiction du modèle sur les données de test
ptitanic.test_Predict<-predict(ptitanic.Tree_Opt,newdata=ptitanic.test, type=« class »)

#Matrice de confusion
mc<-table(ptitanic.test$survived,ptitanic.test_Predict)
print(mc)

##           ptitanic.test_Predict
##            died survived
##   died      187       15
##   survived   43       83

#Erreur de classement
erreur.classement<-1.0-(mc[1,1]+mc[2,2])/sum(mc)
print(erreur.classement)

## [1] 0.1768293

#Taux de prédiction
prediction=mc[2,2]/sum(mc[2,])
print(prediction)

## [1] 0.6587302

Résultats

Et pour finir, si les résultats conviennent, on peut afficher les règles de décision et commencer à les exploiter

#Affichage des règles de construction de l’arbre
print(ptitanic.Tree_Opt)


  1) root 981 376 died (0.61671764 0.38328236)  
    2) sex=male 622 114 died (0.81672026 0.18327974)  
      4) age>=9.5 588  96 died (0.83673469 0.16326531) *
      5) age< 9.5 34  16 survived (0.47058824 0.52941176)          
       10) sibsp>=2.5 14   1 died (0.92857143 0.07142857) *
       11) sibsp< 2.5 20   3 survived (0.15000000 0.85000000) *     
    3) sex=female 359  97 survived (0.27019499 0.72980501)         
      6) pclass=3rd 168  84 died (0.50000000 0.50000000)          
       12) sibsp>=2.5 16   3 died (0.81250000 0.18750000) *
       13) sibsp< 2.5 152  71 survived (0.46710526 0.53289474)            
         26) age>=16.5 125  62 died (0.50400000 0.49600000)  
           52) parch>=3.5 6   1 died (0.83333333 0.16666667) *
           53) parch< 3.5 119  58 survived (0.48739496 0.51260504)               
            106) age>=27.5 31  11 died (0.64516129 0.35483871) *
            107) age< 27.5 88  38 survived (0.43181818 0.56818182) *
         27) age< 16.5 27   8 survived (0.29629630 0.70370370) *
       7) pclass=1st,2nd 191  13 survived (0.06806283 0.93193717) *

Chaque ligne correspond à un noeud ou à une feuille de l’arbre. On commence par la racine qui contient les 981 individus de l’échantillon d’apprentissage.

A l’étape suivante l’arbre découpe la population en fonction de la variable « Sex » et créé les noeuds 2) et 3). Le noeud numéro 3 contient les femmes : 359 individus. Leur taux de survie est de l’ordre de 71% et est donc supérieur au taux moyen de 38%. Ce noeud a donc été labellisé comme « Survived ».  R indique ensuite le nombre d’individus mal classés dans ce noeud : 97 (femmes qui sont décédées). Nous avons ensuite le taux de mauvais classement : 27% et le taux de bon classement : 73%.

On continue ainsi de suite le découpage de l’arbre jusqu’aux feuilles qui sont identifiées par une étoile en fin de ligne.

Bien sûr il est interdit d’utiliser ce type de représentations dans les présentations de résultats, c’est trop technique. Il vaut mieux digérer toutes ces informations et les restituer de manière plus visuelle.

Pour aller plus loin vous pouvez aussi faire un Random Forest avec R.

 

12 réflexions sur “Un arbre de décision avec R”

  1. Thibaud MONTHILLER

    Bonjour dans le cadre de mes études je tente d’adapter votre code à mon jeu de données. Néanmoins lorsque je trace l’arbre élagué le message d’erreur: ‘plot is not a tree just a root’ apparaît. Sauriez-vous m’expliquer sa signification?

    1. Bonjour Thibaud,
      R renvoie ce type de message quand il n’a pas pu construire l’arbre, il dit qu’il ne peut pas le tracer parce qu’il ne s’agit pas d’un arbre seulement d’un noeud. Il doit y avoir un soucis au niveau de vos données. Pour essayer de trouver ce qui pose problème vous pouvez essayer de lancer un modèle plus simple avec un nombre de variables limité et voir si ça passe mieux. Vous pouvez aussi vérifier votre variable cible à prédire.

  2. Bonjour !
    Je suis un nouveau dans le monde de Data_science. Je vous remercie pour vos superbes explications sur la page du travail et j’aime bien.
    1) Je voudrais savoir ce que représente « survived » que vous utilisez dans votre code . Est ce que c’est une partie de la base de données qui est déjà la ou on doit la créer?
    2) Pourquoi vous utilisez « ~. » en après, quel est sont rôle?

    Merci pour votre réponse !

    1. Bonjour Augusma,
      « survived » est une colonne présente dans les données que j’ai utilisé. Il s’agit du dataset du titanic. Il contient la liste des passagers du titanic avec leurs caractéristiques et indique s’ils ont survécu ou non au naufrage. C’est justement cette colonne survived. C’est aussi ce que l’on cherche à prédire avec l’arbre de décision. Quelle est la chance de survie des passagers du titanic en fonction de leurs caractéristiques.

      Pour le 2e point : quand on écrit survived~. on demande à R de prédire la colonne survived en fonction de toutes les autres colonnes du dataset. C’est le rôle du symbole ~

  3. Bonjour Marie-Jeanne,
    merci beaucoup pour ce tutoriel très pédagogique. Vous avez utilisé une séparation de données en jeu d’apprentissage et de test. Dans le cas où on souhaite faire une leave-one-out. Comment doit-on procéder.
    Merci d’avance
    Cordialement
    Sidy

  4. Bonsoir,
    Merci pour ce cours, c’est très intéressant. Mais, j’ai un petit souci quand je sais de l adapter au niveau de ma base de données. R me renvoie ce message « Error in model.frame.default(formula = class ~ ., data = data_train, na.action = function (x) :
    object is not a matrix »
    Je ne comprends pas et pour tant tout est bon.

    Merci de m’apporter quelques éléments de réponses

    1. Bonjour,

      Est ce que vos données sont dans un dataframe?
      Dans cet exemple j’utilise les données Titanic qui sont déjà formatées dans R. Mais si vous utilisez un dataset importé dans R il faudra le transformer en dataframe avec la fonction data.frame

      Ca peut aussi être d’autres problèmes mais vu l’erreur renvoyée par R il y a des chances que ce soit ça.

      N’hesitez pas a me dire si ça ne fonctionne toujours pas, on va trouver 😉

Laisser un commentaire

Votre adresse e-mail ne sera pas publiée. Les champs obligatoires sont indiqués avec *