static void intel_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
struct intel_svm *svm = container_of(mn, struct intel_svm, notifier);
+ struct intel_svm_dev *sdev;
+ /* This might end up being called from exit_mmap(), *before* the page
+ * tables are cleared. And __mmu_notifier_release() will delete us from
+ * the list of notifiers so that our invalidate_range() callback doesn't
+ * get called when the page tables are cleared. So we need to protect
+ * against hardware accessing those page tables.
+ *
+ * We do it by clearing the entry in the PASID table and then flushing
+ * the IOTLB and the PASID table caches. This might upset hardware;
+ * perhaps we'll want to point the PASID to a dummy PGD (like the zero
+ * page) so that we end up taking a fault that the hardware really
+ * *has* to handle gracefully without affecting other processes.
+ */
svm->iommu->pasid_table[svm->pasid].val = 0;
+ wmb();
+
+ rcu_read_lock();
+ list_for_each_entry_rcu(sdev, &svm->devs, list) {
+ intel_flush_pasid_dev(svm, sdev, svm->pasid);
+ intel_flush_svm_range_dev(svm, sdev, 0, -1, 0, !svm->mm);
+ }
+ rcu_read_unlock();
- /* There's no need to do any flush because we can't get here if there
- * are any devices left anyway. */
- WARN_ON(!list_empty(&svm->devs));
}
static const struct mmu_notifier_ops intel_mmuops = {
goto out;
}
iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1;
- mm = NULL;
} else
iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11);
wmb();
kfree_rcu(sdev, rcu);
if (list_empty(&svm->devs)) {
- mmu_notifier_unregister(&svm->notifier, svm->mm);
idr_remove(&svm->iommu->pasid_idr, svm->pasid);
if (svm->mm)
- mmput(svm->mm);
+ mmu_notifier_unregister(&svm->notifier, svm->mm);
+
/* We mandate that no page faults may be outstanding
* for the PASID when intel_svm_unbind_mm() is called.
* If that is not obeyed, subtle errors will happen.
* any faults on kernel addresses. */
if (!svm->mm)
goto bad_req;
+ /* If the mm is already defunct, don't handle faults. */
+ if (!atomic_inc_not_zero(&svm->mm->mm_users))
+ goto bad_req;
down_read(&svm->mm->mmap_sem);
vma = find_extend_vma(svm->mm, address);
if (!vma || address < vma->vm_start)
result = QI_RESP_SUCCESS;
invalid:
up_read(&svm->mm->mmap_sem);
+ mmput(svm->mm);
bad_req:
/* Accounting for major/minor faults? */
rcu_read_lock();