static int handle_task_switch(struct kvm_vcpu *vcpu, struct kvm_run *kvm_run)
 {
+       struct vcpu_vmx *vmx = to_vmx(vcpu);
        unsigned long exit_qualification;
        u16 tss_selector;
        int reason;
        exit_qualification = vmcs_readl(EXIT_QUALIFICATION);
 
        reason = (u32)exit_qualification >> 30;
+       if (reason == TASK_SWITCH_GATE && vmx->vcpu.arch.nmi_injected &&
+           (vmx->idt_vectoring_info & VECTORING_INFO_VALID_MASK) &&
+           (vmx->idt_vectoring_info & VECTORING_INFO_TYPE_MASK)
+           == INTR_TYPE_NMI_INTR) {
+               vcpu->arch.nmi_injected = false;
+               if (cpu_has_virtual_nmis())
+                       vmcs_set_bits(GUEST_INTERRUPTIBILITY_INFO,
+                                     GUEST_INTR_STATE_NMI);
+       }
        tss_selector = exit_qualification;
 
        return kvm_task_switch(vcpu, tss_selector, reason);
 
        if ((vectoring_info & VECTORING_INFO_VALID_MASK) &&
                        (exit_reason != EXIT_REASON_EXCEPTION_NMI &&
-                       exit_reason != EXIT_REASON_EPT_VIOLATION))
-               printk(KERN_WARNING "%s: unexpected, valid vectoring info and "
-                      "exit reason is 0x%x\n", __func__, exit_reason);
+                       exit_reason != EXIT_REASON_EPT_VIOLATION &&
+                       exit_reason != EXIT_REASON_TASK_SWITCH))
+               printk(KERN_WARNING "%s: unexpected, valid vectoring info "
+                      "(0x%x) and exit reason is 0x%x\n",
+                      __func__, vectoring_info, exit_reason);
        if (exit_reason < kvm_vmx_max_exit_handlers
            && kvm_vmx_exit_handlers[exit_reason])
                return kvm_vmx_exit_handlers[exit_reason](vcpu, kvm_run);