ret = NULL;
        }
 
-       if (ret && !try_module_get(ret->owner))
+       if (!ret)
+               goto out_unlock;
+
+       if (!try_module_get(ret->owner)) {
+               ret = NULL;
+               goto out_unlock;
+       }
+
+       if (ret->get_device && ret->get_device(ret)) {
+               module_put(ret->owner);
                ret = NULL;
+               goto out_unlock;
+       }
 
-       if (ret)
-               ret->usecount++;
+       ret->usecount++;
 
+out_unlock:
        mutex_unlock(&mtd_table_mutex);
        return ret;
 }
 
 struct mtd_info *get_mtd_device_nm(const char *name)
 {
-       int i;
-       struct mtd_info *mtd = ERR_PTR(-ENODEV);
+       int i, err = -ENODEV;
+       struct mtd_info *mtd = NULL;
 
        mutex_lock(&mtd_table_mutex);
 
                }
        }
 
-       if (i == MAX_MTD_DEVICES)
+       if (!mtd)
                goto out_unlock;
 
        if (!try_module_get(mtd->owner))
                goto out_unlock;
 
+       if (mtd->get_device) {
+               err = mtd->get_device(mtd);
+               if (err)
+                       goto out_put;
+       }
+
        mtd->usecount++;
+       mutex_unlock(&mtd_table_mutex);
+       return mtd;
 
+out_put:
+       module_put(mtd->owner);
 out_unlock:
        mutex_unlock(&mtd_table_mutex);
-       return mtd;
+       return ERR_PTR(err);
 }
 
 void put_mtd_device(struct mtd_info *mtd)
 
        mutex_lock(&mtd_table_mutex);
        c = --mtd->usecount;
+       if (mtd->put_device)
+               mtd->put_device(mtd);
        mutex_unlock(&mtd_table_mutex);
        BUG_ON(c < 0);
 
 
 
        struct module *owner;
        int usecount;
+
+       /* If the driver is something smart, like UBI, it may need to maintain
+        * its own reference counting. The below functions are only for driver.
+        * The driver may register its callbacks. These callbacks are not
+        * supposed to be called by MTD users */
+       int (*get_device) (struct mtd_info *mtd);
+       void (*put_device) (struct mtd_info *mtd);
 };