7e1dafccb11e91e79137a0c3be6786ae9d06ed44
[firefly-linux-kernel-4.4.55.git] / drivers / infiniband / hw / usnic / usnic_uiom_interval_tree.c
1 #include <linux/init.h>
2 #include <linux/list.h>
3 #include <linux/slab.h>
4 #include <linux/list_sort.h>
5 #include <linux/version.h>
6
7 #include <linux/interval_tree_generic.h>
8 #include "usnic_uiom_interval_tree.h"
9
/* Key accessors passed to INTERVAL_TREE_DEFINE() at the end of this file. */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * MAKE_NODE - allocate a node covering [start, end] and store it in @node.
 *
 * On allocation failure it sets @err to -ENOMEM and jumps to the label
 * @err_out, so it may only be used inside a function that defines that
 * label and releases there whatever it has already allocated.
 *
 * Every argument except the label is parenthesized so that expression
 * arguments expand safely.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
					(ref_cnt), (flags));		\
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* MARK_FOR_ADD - append @node to the tail of @list through its ->link. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/*
 * MAKE_NODE_AND_APPEND - MAKE_NODE() and, if the allocation succeeded,
 * MARK_FOR_ADD() onto @list. Shares MAKE_NODE()'s goto-on-failure contract.
 */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
				err_out);				\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* FLAGS_EQUAL - nonzero if @flags1 and @flags2 agree on every bit of @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
36
37 static struct usnic_uiom_interval_node*
38 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
39                                 int flags)
40 {
41         struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
42                                                                 GFP_ATOMIC);
43         if (!interval)
44                 return NULL;
45
46         interval->start = start;
47         interval->last = last;
48         interval->flags = flags;
49         interval->ref_cnt = ref_cnt;
50
51         return interval;
52 }
53
54 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
55 {
56         struct usnic_uiom_interval_node *node_a, *node_b;
57
58         node_a = list_entry(a, struct usnic_uiom_interval_node, link);
59         node_b = list_entry(b, struct usnic_uiom_interval_node, link);
60
61         /* long to int */
62         if (node_a->start < node_b->start)
63                 return -1;
64         else if (node_a->start > node_b->start)
65                 return 1;
66
67         return 0;
68 }
69
70 static void
71 find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
72                                         unsigned long last,
73                                         struct list_head *list)
74 {
75         struct usnic_uiom_interval_node *node;
76
77         INIT_LIST_HEAD(list);
78
79         for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
80                 node;
81                 node = usnic_uiom_interval_tree_iter_next(node, start, last))
82                 list_add_tail(&node->link, list);
83
84         list_sort(NULL, list, interval_cmp);
85 }
86
87 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
88                                         int flags, int flag_mask,
89                                         struct rb_root *root,
90                                         struct list_head *diff_set)
91 {
92         struct usnic_uiom_interval_node *interval, *tmp;
93         int err = 0;
94         long int pivot = start;
95         LIST_HEAD(intersection_set);
96
97         INIT_LIST_HEAD(diff_set);
98
99         find_intervals_intersection_sorted(root, start, last,
100                                                 &intersection_set);
101
102         list_for_each_entry(interval, &intersection_set, link) {
103                 if (pivot < interval->start) {
104                         MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
105                                                 1, flags, err, err_out,
106                                                 diff_set);
107                         pivot = interval->start;
108                 }
109
110                 /*
111                  * Invariant: Set [start, pivot] is either in diff_set or root,
112                  * but not in both.
113                  */
114
115                 if (pivot > interval->last) {
116                         continue;
117                 } else if (pivot <= interval->last &&
118                                 FLAGS_EQUAL(interval->flags, flags,
119                                 flag_mask)) {
120                         pivot = interval->last + 1;
121                 }
122         }
123
124         if (pivot <= last)
125                 MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
126                                         diff_set);
127
128         return 0;
129
130 err_out:
131         list_for_each_entry_safe(interval, tmp, diff_set, link) {
132                 list_del(&interval->link);
133                 kfree(interval);
134         }
135
136         return err;
137 }
138
139 void usnic_uiom_put_interval_set(struct list_head *intervals)
140 {
141         struct usnic_uiom_interval_node *interval, *tmp;
142         list_for_each_entry_safe(interval, tmp, intervals, link)
143                 kfree(interval);
144 }
145
146 int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
147                                 unsigned long last, int flags)
148 {
149         struct usnic_uiom_interval_node *interval, *tmp;
150         unsigned long istart, ilast;
151         int iref_cnt, iflags;
152         unsigned long lpivot = start;
153         int err = 0;
154         LIST_HEAD(to_add);
155         LIST_HEAD(intersection_set);
156
157         find_intervals_intersection_sorted(root, start, last,
158                                                 &intersection_set);
159
160         list_for_each_entry(interval, &intersection_set, link) {
161                 /*
162                  * Invariant - lpivot is the left edge of next interval to be
163                  * inserted
164                  */
165                 istart = interval->start;
166                 ilast = interval->last;
167                 iref_cnt = interval->ref_cnt;
168                 iflags = interval->flags;
169
170                 if (istart < lpivot) {
171                         MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
172                                                 iflags, err, err_out, &to_add);
173                 } else if (istart > lpivot) {
174                         MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
175                                                 err, err_out, &to_add);
176                         lpivot = istart;
177                 } else {
178                         lpivot = istart;
179                 }
180
181                 if (ilast > last) {
182                         MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
183                                                 iflags | flags, err, err_out,
184                                                 &to_add);
185                         MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
186                                                 iflags, err, err_out, &to_add);
187                 } else {
188                         MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
189                                                 iflags | flags, err, err_out,
190                                                 &to_add);
191                 }
192
193                 lpivot = ilast + 1;
194         }
195
196         if (lpivot <= last)
197                 MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
198                                         &to_add);
199
200         list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
201                 usnic_uiom_interval_tree_remove(interval, root);
202                 kfree(interval);
203         }
204
205         list_for_each_entry(interval, &to_add, link)
206                 usnic_uiom_interval_tree_insert(interval, root);
207
208         return 0;
209
210 err_out:
211         list_for_each_entry_safe(interval, tmp, &to_add, link)
212                 kfree(interval);
213
214         return err;
215 }
216
217 void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
218                                 unsigned long last, struct list_head *removed)
219 {
220         struct usnic_uiom_interval_node *interval;
221
222         for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
223                         interval;
224                         interval = usnic_uiom_interval_tree_iter_next(interval,
225                                                                         start,
226                                                                         last)) {
227                 if (--interval->ref_cnt == 0)
228                         list_add_tail(&interval->link, removed);
229         }
230
231         list_for_each_entry(interval, removed, link)
232                 usnic_uiom_interval_tree_remove(interval, root);
233 }
234
235 INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
236                         unsigned long, __subtree_last,
237                         START, LAST, , usnic_uiom_interval_tree)