moar CRTP to minimize copypasta for inbound/outbound handlers
[folly.git] / folly / wangle / channel / HandlerContext.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18
19 #include <folly/io/async/AsyncTransport.h>
20 #include <folly/futures/Future.h>
21 #include <folly/ExceptionWrapper.h>
22
23 namespace folly { namespace wangle {
24
25 template <class In, class Out>
26 class HandlerContext {
27  public:
28   virtual ~HandlerContext() {}
29
30   virtual void fireRead(In msg) = 0;
31   virtual void fireReadEOF() = 0;
32   virtual void fireReadException(exception_wrapper e) = 0;
33
34   virtual Future<void> fireWrite(Out msg) = 0;
35   virtual Future<void> fireClose() = 0;
36
37   virtual std::shared_ptr<AsyncTransport> getTransport() = 0;
38
39   virtual void setWriteFlags(WriteFlags flags) = 0;
40   virtual WriteFlags getWriteFlags() = 0;
41
42   virtual void setReadBufferSettings(
43       uint64_t minAvailable,
44       uint64_t allocationSize) = 0;
45   virtual std::pair<uint64_t, uint64_t> getReadBufferSettings() = 0;
46
47   /* TODO
48   template <class H>
49   virtual void addHandlerBefore(H&&) {}
50   template <class H>
51   virtual void addHandlerAfter(H&&) {}
52   template <class H>
53   virtual void replaceHandler(H&&) {}
54   virtual void removeHandler() {}
55   */
56 };
57
58 class PipelineContext {
59  public:
60   virtual ~PipelineContext() {}
61
62   virtual void attachPipeline() = 0;
63   virtual void detachPipeline() = 0;
64
65   virtual void attachTransport() = 0;
66   virtual void detachTransport() = 0;
67
68   template <class H, class HandlerContext>
69   void attachContext(H* handler, HandlerContext* ctx) {
70     if (++handler->attachCount_ == 1) {
71       handler->ctx_ = ctx;
72     } else {
73       handler->ctx_ = nullptr;
74     }
75   }
76
77   void link(PipelineContext* other) {
78     setNextIn(other);
79     other->setNextOut(this);
80   }
81
82  protected:
83   virtual void setNextIn(PipelineContext* ctx) = 0;
84   virtual void setNextOut(PipelineContext* ctx) = 0;
85 };
86
87 template <class In>
88 class InboundLink {
89  public:
90   virtual ~InboundLink() {}
91   virtual void read(In msg) = 0;
92   virtual void readEOF() = 0;
93   virtual void readException(exception_wrapper e) = 0;
94 };
95
96 template <class Out>
97 class OutboundLink {
98  public:
99   virtual ~OutboundLink() {}
100   virtual Future<void> write(Out msg) = 0;
101   virtual Future<void> close() = 0;
102 };
103
104 template <class P, class H, class Context>
105 class ContextImplBase : public PipelineContext {
106  public:
107   ~ContextImplBase() {}
108
109   H* getHandler() {
110     return handler_.get();
111   }
112
113   void initialize(P* pipeline, std::shared_ptr<H> handler) {
114     pipeline_ = pipeline;
115     handler_ = std::move(handler);
116   }
117
118   // PipelineContext overrides
119   void attachPipeline() override {
120     if (!attached_) {
121       this->attachContext(handler_.get(), impl_);
122       handler_->attachPipeline(impl_);
123       attached_ = true;
124     }
125   }
126
127   void detachPipeline() override {
128     handler_->detachPipeline(impl_);
129     attached_ = false;
130   }
131
132   void attachTransport() override {
133     DestructorGuard dg(pipeline_);
134     handler_->attachTransport(impl_);
135   }
136
137   void detachTransport() override {
138     DestructorGuard dg(pipeline_);
139     handler_->detachTransport(impl_);
140   }
141
142   void setNextIn(PipelineContext* ctx) override {
143     auto nextIn = dynamic_cast<InboundLink<typename H::rout>*>(ctx);
144     if (nextIn) {
145       nextIn_ = nextIn;
146     } else {
147       throw std::invalid_argument("wrong type in setNextIn");
148     }
149   }
150
151   void setNextOut(PipelineContext* ctx) override {
152     auto nextOut = dynamic_cast<OutboundLink<typename H::wout>*>(ctx);
153     if (nextOut) {
154       nextOut_ = nextOut;
155     } else {
156       throw std::invalid_argument("wrong type in setNextOut");
157     }
158   }
159
160  protected:
161   Context* impl_;
162   P* pipeline_;
163   std::shared_ptr<H> handler_;
164   InboundLink<typename H::rout>* nextIn_{nullptr};
165   OutboundLink<typename H::wout>* nextOut_{nullptr};
166
167  private:
168   bool attached_{false};
169   using DestructorGuard = typename P::DestructorGuard;
170 };
171
172 template <class P, class H>
173 class ContextImpl : public HandlerContext<typename H::rout,
174                                                  typename H::wout>,
175                     public InboundLink<typename H::rin>,
176                     public OutboundLink<typename H::win>,
177                     public ContextImplBase<P, H, HandlerContext<typename H::rout, typename H::wout>> {
178  public:
179   typedef typename H::rin Rin;
180   typedef typename H::rout Rout;
181   typedef typename H::win Win;
182   typedef typename H::wout Wout;
183
184   explicit ContextImpl(P* pipeline, std::shared_ptr<H> handler) {
185     this->impl_ = this;
186     this->initialize(pipeline, std::move(handler));
187   }
188
189   // For StaticPipeline
190   ContextImpl() {
191     this->impl_ = this;
192   }
193
194   ~ContextImpl() {}
195
196   // HandlerContext overrides
197   void fireRead(Rout msg) override {
198     DestructorGuard dg(this->pipeline_);
199     if (this->nextIn_) {
200       this->nextIn_->read(std::forward<Rout>(msg));
201     } else {
202       LOG(WARNING) << "read reached end of pipeline";
203     }
204   }
205
206   void fireReadEOF() override {
207     DestructorGuard dg(this->pipeline_);
208     if (this->nextIn_) {
209       this->nextIn_->readEOF();
210     } else {
211       LOG(WARNING) << "readEOF reached end of pipeline";
212     }
213   }
214
215   void fireReadException(exception_wrapper e) override {
216     DestructorGuard dg(this->pipeline_);
217     if (this->nextIn_) {
218       this->nextIn_->readException(std::move(e));
219     } else {
220       LOG(WARNING) << "readException reached end of pipeline";
221     }
222   }
223
224   Future<void> fireWrite(Wout msg) override {
225     DestructorGuard dg(this->pipeline_);
226     if (this->nextOut_) {
227       return this->nextOut_->write(std::forward<Wout>(msg));
228     } else {
229       LOG(WARNING) << "write reached end of pipeline";
230       return makeFuture();
231     }
232   }
233
234   Future<void> fireClose() override {
235     DestructorGuard dg(this->pipeline_);
236     if (this->nextOut_) {
237       return this->nextOut_->close();
238     } else {
239       LOG(WARNING) << "close reached end of pipeline";
240       return makeFuture();
241     }
242   }
243
244   std::shared_ptr<AsyncTransport> getTransport() override {
245     return this->pipeline_->getTransport();
246   }
247
248   void setWriteFlags(WriteFlags flags) override {
249     this->pipeline_->setWriteFlags(flags);
250   }
251
252   WriteFlags getWriteFlags() override {
253     return this->pipeline_->getWriteFlags();
254   }
255
256   void setReadBufferSettings(
257       uint64_t minAvailable,
258       uint64_t allocationSize) override {
259     this->pipeline_->setReadBufferSettings(minAvailable, allocationSize);
260   }
261
262   std::pair<uint64_t, uint64_t> getReadBufferSettings() override {
263     return this->pipeline_->getReadBufferSettings();
264   }
265
266   // InboundLink overrides
267   void read(Rin msg) override {
268     DestructorGuard dg(this->pipeline_);
269     this->handler_->read(this, std::forward<Rin>(msg));
270   }
271
272   void readEOF() override {
273     DestructorGuard dg(this->pipeline_);
274     this->handler_->readEOF(this);
275   }
276
277   void readException(exception_wrapper e) override {
278     DestructorGuard dg(this->pipeline_);
279     this->handler_->readException(this, std::move(e));
280   }
281
282   // OutboundLink overrides
283   Future<void> write(Win msg) override {
284     DestructorGuard dg(this->pipeline_);
285     return this->handler_->write(this, std::forward<Win>(msg));
286   }
287
288   Future<void> close() override {
289     DestructorGuard dg(this->pipeline_);
290     return this->handler_->close(this);
291   }
292
293  private:
294   using DestructorGuard = typename P::DestructorGuard;
295 };
296
297 }}