Category Archives: kmeans

How to make a custom estimator class and a custom scorer for cross-validation with the sklearn API on your own model

I built a combined weak-classifier model and needed a custom estimator and a custom scorer for it. I went through a few Stack Overflow threads, but none of them specifically targeted cross-validation in sklearn.

Then I figured I would try subclassing BaseEstimator and writing my own scorer. It WORKED. :>

So I am posting instructions here on how to do it; hopefully it will be useful to you.


  1. Write your own estimator class; just make sure it subclasses BaseEstimator. (In Python this works like extending an interface or abstract class: BaseEstimator provides the basic functionality every estimator needs, such as get_params and set_params.)
  2. Write your loss or gain function, then turn it into your own scorer.
  3. Use the sklearn API to do cross-validation with whatever you created in steps 1 and 2.

Code below. Please read the comments; they are important.

#create a custom estimator class

#Keep in mind: this is a simplified version. You can treat it like any other class; just make sure the signatures of fit and predict stay the same (or give any extra parameters default values).

import numpy as np
from sklearn import tree
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans

class custom_classifier(BaseEstimator):
  #KMeans clustering model
  __clusters = None
  #decision tree model
  __tree = None
  #training features
  __X = None
  #training labels
  __y = None
  #columns selected from a pandas dataframe
  __columns = None

  def fit(self, X, y, **kwargs):
    #fit both sub-models; sklearn expects fit to return self
    self.fit_kmeans(X, y)
    self.fit_decisiontree(X, y)
    return self

  def predict(self, X):
    result_kmeans = self.__clusters.predict(X)
    result_tree = self.__tree.predict(X)
    #here the tree's output is returned as-is; combine the two results however your model requires
    result = result_tree
    return np.array(result)

  def fit_kmeans(self, X, y):
    clusters = KMeans(n_clusters=4, random_state=0).fit(X)
    #the error center should have the lowest number of labels (implementation not shown here)
    self.__clusters = clusters

  def fit_decisiontree(self, X, y):
    temp_tree = tree.DecisionTreeClassifier(criterion='entropy', max_depth=3).fit(X, y)
    self.__tree = temp_tree

Now that we have our class, we need to build the hit/loss function:

#again, feel free to change anything inside the hit function, as long as the function signature stays the same.

def seg_tree_hit_func(ground_truth, predictions):
  total_hit = 0
  total_number = 0
  for i in range(len(predictions)):
    if predictions[i] == 2:
      #only predictions for the segment of interest (label 2) are scored; everything else is skipped
      total_number += 1
      total_hit += (1 - abs(ground_truth[i] - predictions[i]))
  print('skipped:', len(predictions) - total_number, '/', len(predictions), 'instances')
  return total_hit / total_number if total_number != 0 else 0

Now we still need to build the scorer.

from sklearn.metrics import make_scorer

#make our own scorer
score = make_scorer(seg_tree_hit_func, greater_is_better=True)

We have our scorer and our estimator, so we can start the cross-validation task:

#change the 7 to however many folds you want to run.

from sklearn.model_selection import cross_val_score
scores = cross_val_score(custom_classifier(), X, Y, cv=7, scoring=score)

There it is! You have your own scorer and estimator, and you can plug them into anything in the sklearn API easily.
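To see the three pieces working together end to end, here is a minimal self-contained sketch. The majority-vote estimator, the exact-match hit function, and the toy data are my own stand-ins for the combined kmeans+tree model above, not the original implementation:

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score

class MajorityClassifier(BaseEstimator):
    """Toy estimator: always predicts the most frequent training label."""
    def fit(self, X, y, **kwargs):
        values, counts = np.unique(y, return_counts=True)
        self.majority_ = values[np.argmax(counts)]
        return self  # sklearn convention: fit returns self

    def predict(self, X):
        return np.full(len(X), self.majority_)

def hit_func(ground_truth, predictions):
    # fraction of exact matches; a simple stand-in for the segment-hit logic above
    return np.mean(ground_truth == predictions)

score = make_scorer(hit_func, greater_is_better=True)

# toy data: 20 points, labels 0 and 1
X = np.arange(20).reshape(-1, 1)
y = np.array([0] * 12 + [1] * 8)

scores = cross_val_score(MajorityClassifier(), X, y, cv=4, scoring=score)
print(scores)
```

Everything here follows the same three steps: subclass BaseEstimator, wrap a plain function with make_scorer, and hand both to cross_val_score.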


Hope this helps.

BIC and AIC in Python using scipy.cluster.vq kmeans

def aic(data, distortion, clusterNumber):
    #AIC = k-means error + 2 * k * d, where k is the cluster count and d is the dimensionality
    return distortion + 2 * clusterNumber * len(data[0])

A quote from people using R:

To compute BIC, add 0.5*k*d*log(n) (where k is the number of means, d is the length of a vector in your dataset, and n is the number of data points) to the standard k-means error function.

The standard k-means penalty is \sum_n (m_k(n)-x_n)^2, where m_k(n) is the mean associated with the nth data point. This penalty can be interpreted as a log probability, so BIC is perfectly valid.

BIC just adds an additional penalty term to the k-means error proportional to k.
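To make the penalty concrete, here is the arithmetic for some illustrative numbers (k = 4, d = 2, n = 100 are my own picks, not from the quote):

```python
import math

k, d, n = 4, 2, 100  # illustrative: 4 means, 2-D vectors, 100 points
penalty = 0.5 * k * d * math.log(n)
print(round(penalty, 2))  # about 18.42 added on top of the k-means error
```

The penalty grows linearly in k, which is what lets BIC punish models with too many clusters.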


def bic(data, distortion, clusterNumber):
    import math
    import numpy as np
    if not isinstance(data, np.ndarray):
        print('invalid data type in bic')
        return 0
    #per the quote: n is the number of data points (len(data)), d is the dimensionality (len(data[0]))
    return distortion + 0.5 * math.log(len(data)) * clusterNumber * len(data[0])
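As a sketch of how these functions might be wired up with scipy.cluster.vq (the toy blob data and the rescaling are my own assumptions; note that scipy's kmeans returns the *mean* distance per point rather than a summed squared error, so multiplying by n is only a rough stand-in for the error term in the quote):

```python
import math
import numpy as np
from scipy.cluster.vq import kmeans, whiten

# simplified versions of the functions above
def aic(data, distortion, clusterNumber):
    return distortion + 2 * clusterNumber * len(data[0])

def bic(data, distortion, clusterNumber):
    return distortion + 0.5 * math.log(len(data)) * clusterNumber * len(data[0])

# toy data: two well-separated 2-D blobs (purely illustrative)
rng = np.random.default_rng(0)
data = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(10, 1, (50, 2))])
obs = whiten(data)  # scipy's kmeans expects whitened observations

results = {}
for k in (1, 2, 3):
    # kmeans returns (codebook, mean_distortion); scale by n for a total-error figure
    codebook, mean_dist = kmeans(obs, k)
    results[k] = bic(obs, mean_dist * len(obs), k)
    print(k, results[k])
```

With two true clusters, the k = 2 score should come out lowest: the error term drops sharply going from 1 to 2 clusters, while the BIC penalty only grows a little with each extra k.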