header_utils
Loading...
Searching...
No Matches
di_impl.h
1
4
5#pragma once
6
7namespace ghassanpl::di
8{
9 namespace detail
10 {
12
/// Maps a constructor parameter type to the way its value is obtained from a container.
/// The primary template is intentionally left undefined: parameter types without a
/// specialization make `ArgumentResolver<T>::Resolve(...)` ill-formed, so they fail IsSupportedArgument.
template <typename TArg>
struct ArgumentResolver;

/// `std::shared_ptr<TArg>` parameters are satisfied by resolving TArg from the container.
template <typename TArg>
struct ArgumentResolver<std::shared_ptr<TArg>>
{
	using Type = std::shared_ptr<TArg>;

	template <typename CONTAINER>
	static Type Resolve(CONTAINER& source)
	{
		return source.template Resolve<TArg>();
	}
};
27
29 /*
30 template <class TArg>
31 struct ArgumentResolver<std::unique_ptr<TArg>>
32 {
33 typedef std::unique_ptr<TArg> Type;
34
35 template <typename CONTAINER>
36 static Type Resolve(CONTAINER& container)
37 {
38 return container.template Resolve<TArg>(ForceTransient);
39 }
40 };
41 */
42
43 template <class TArg>
44 struct ArgumentResolver<std::vector<std::shared_ptr<TArg>>>
45 {
46 typedef std::vector<std::shared_ptr<TArg>> Type;
47
48 template <typename CONTAINER>
49 static Type Resolve(CONTAINER& container)
50 {
51 return container.template ResolveAll<TArg>();
52 }
53 };
54
55 template <>
56 struct ArgumentResolver<Container&>
57 {
58 typedef Container& Type;
59
60 static Type Resolve(Container& container)
61 {
62 return container;
63 }
64 };
65
66 template <class T, class TArgumentPack>
67 struct ConstructorDescriptor;
68
69 template <class T>
70 struct ConstructorDescriptor<T, std::tuple<>>
71 {
72 static std::function<std::shared_ptr<T>(Container&)> CreateFactory()
73 {
74 return [](Container&) { return std::make_shared<T>(); };
75 }
76 static T* Create(Container& c)
77 {
78 return new T;
79 }
80 };
81
82 template <typename T>
83 concept IsSupportedArgument = requires (Container& cnt) { { ArgumentResolver<T>::Resolve(cnt) }; };
84
// Constructor descriptor for classes constructed with one or more injected arguments.
// TAnyArgument... is presumably the placeholder tuple produced by ConstructorTypologyDeducer
// (a std::tuple of WrapAndGet<T, I>, one per constructor parameter) - TODO confirm.
// NOTE(review): original lines 91 and 96, which construct T, are missing from this listing.
// As visible, the factory lambda returns nothing (so it cannot convert to the declared
// std::function<std::shared_ptr<T>(Container&)>) and Create() has no return statement.
// Presumably each argument is ArgumentResolverInvoker<typename TAnyArgument::Type>(container);
// restore these lines from upstream and confirm.
85 template <class T, class... TAnyArgument>
86 struct ConstructorDescriptor<T, std::tuple<TAnyArgument...>>
87 {
// Intended to return a factory that builds a std::shared_ptr<T> from container-resolved arguments.
88 static std::function<std::shared_ptr<T>(Container&)> CreateFactory()
89 {
90 return [](Container& container) {
92 };
93 }
// Intended to return a new T. Container::Create / CreateRaw take ownership of the raw pointer.
94 static T* Create(Container& c)
95 {
97 }
98 };
99
/// Sentinel reported by ConstructorTypologyDeducer when no supported constructor arity is found.
/// Its nested Type names the sentinel itself, so `typename Deducer::Type` yields it directly.
struct ConstructorTypologyNotSupported
{
	using Type = ConstructorTypologyNotSupported;
};
104
105 template <class TParent>
106 struct ArgumentResolverInvoker
107 {
108 explicit ArgumentResolverInvoker(Container& container) : mContainer(container) {}
109
110 template <class T>
111 requires (!std::is_convertible_v<TParent, T> && IsSupportedArgument<T>)
112 operator T()
113 {
114 return ArgumentResolver<T>::Resolve(mContainer);
115 }
116
117 template <class T>
118 requires (!std::is_convertible_v<TParent, T> && IsSupportedArgument<T&>)
119 operator T& ()
120 {
121 return ArgumentResolver<T&>::Resolve(mContainer);
122 }
123
124 private:
125
126 Container& mContainer;
127 };
128
129 template <class TParent>
130 struct AnyArgument
131 {
132 using Type = TParent;
133
134 template <class T>
135 requires (!std::is_convertible_v<TParent, T>&& IsSupportedArgument<T&>)
136 operator T& ()
137 {
138 }
139
140 template <class T>
141 requires (!std::is_convertible_v<TParent, T>&& IsSupportedArgument<T>)
142 operator T()
143 {
144 }
145 };
146
147 template <class T, int>
148 struct WrapAndGet : AnyArgument<T> {};
149
150 template <class, class>
151 struct ConstructorTypologyDeducer;
152
153 // Initial recursion state
154 template <typename T>
155 requires std::is_constructible_v<T>
156 struct ConstructorTypologyDeducer<T, std::integer_sequence<int>>
157 {
158 using Type = std::tuple<>;
159 };
160
// Initial recursion state for classes that are NOT default-constructible.
// NOTE(review): original line 165 (the body) is missing from this listing. As visible, this
// specialization defines no Type. Presumably it continues the search with one placeholder, e.g.
// `using Type = typename ConstructorTypologyDeducer<T, std::make_integer_sequence<int, 1>>::Type;`
// - confirm against upstream.
161 template <class T>
162 requires (!std::is_constructible_v<T>)
163 struct ConstructorTypologyDeducer<T, std::integer_sequence<int>>
164 {
166 };
167
/// Largest constructor arity ConstructorTypologyDeducer will probe. Classes that are not constructible
/// with this many placeholders deduce to ConstructorTypologyNotSupported.
/// `inline constexpr` (not `static`) gives one shared definition across translation units.
inline constexpr std::size_t MaximumArgumentCount = 20;
169
170 // Common recusion state
171 template <class T, int... NthArgument>
172 requires (sizeof...(NthArgument) > 0 && sizeof...(NthArgument) < MaximumArgumentCount) && std::is_constructible_v<T, WrapAndGet<T, NthArgument>...>
173 struct ConstructorTypologyDeducer<T, std::integer_sequence<int, NthArgument...>>
174 {
175 using Type = std::tuple<WrapAndGet<T, NthArgument>...>;
176 };
177
178 template <class T, int... NthArgument>
179 requires (sizeof...(NthArgument) > 0 && sizeof...(NthArgument) < MaximumArgumentCount) && (!std::is_constructible_v<T, WrapAndGet<T, NthArgument>...>)
180 struct ConstructorTypologyDeducer<T, std::integer_sequence<int, NthArgument...>>
181 {
182 using Type = typename ConstructorTypologyDeducer<T, std::make_integer_sequence<int, sizeof...(NthArgument) + 1>>::Type;
183 };
184
185 // Last recursion state
186 template <class T, int... NthArgument>
187 requires (sizeof...(NthArgument) == MaximumArgumentCount) && std::is_constructible_v<T, WrapAndGet<T, NthArgument>...>
188 struct ConstructorTypologyDeducer<T, std::integer_sequence<int, NthArgument...>>
189 {
190 using Type = std::tuple<WrapAndGet<T, NthArgument>...>;
191 };
192
193 template <class T, int... NthArgument>
194 requires (sizeof...(NthArgument) == MaximumArgumentCount) && (!std::is_constructible_v<T, WrapAndGet<T, NthArgument>...>)
195 struct ConstructorTypologyDeducer<T, std::integer_sequence<int, NthArgument...>> : ConstructorTypologyNotSupported
196 {
197 };
198
// NOTE(review): the declaration this template header introduces (original line 200) is missing
// from this listing. Container::RegisterType, Create and CreateRaw use
// detail::ConstructorDescriptorForClass<T>, presumably an alias for
// ConstructorDescriptor<T, <argument tuple deduced by ConstructorTypologyDeducer>> declared here - TODO confirm.
199 template <typename T>
201 }
202
203
204 template <typename INTERFACE>
205 struct Container::ImplementationContainer
206 {
207 Lifetime CustomLifetime = Lifetime::Default;
208 std::string Name;
209
210 std::shared_ptr<INTERFACE> StrongInstancePointer;
211 std::map<std::thread::id, std::shared_ptr<INTERFACE>> ThreadInstances;
212 mutable std::weak_ptr<INTERFACE> WeakInstancePointer;
213 std::function<std::shared_ptr<INTERFACE>(Container&)> mFactory;
214 std::function<void(Container&, std::shared_ptr<INTERFACE>)> mOnCreate;
215
217 void Set(std::string_view name) { Name = std::string{ name }; }
218 void Set(Lifetime lifetime) { CustomLifetime = lifetime; }
219 void Set(DefaultImplementationStruct) { }
220 template <std::derived_from<INTERFACE> T>
221 void Set(std::function<std::shared_ptr<T>(Container&)> factory)
222 {
223 mFactory = [factory = std::move(factory)](Container& container) {
224 return container.Instantiate<INTERFACE>(factory);
225 };
226 }
227 void Set(std::shared_ptr<INTERFACE> instance) { StrongInstancePointer = std::move(instance); }
228 void Set(INTERFACE* instance) { StrongInstancePointer = std::shared_ptr<INTERFACE>{ std::shared_ptr<INTERFACE>{}, instance }; }
229 void Set(std::function<void(Container&, std::shared_ptr<INTERFACE>)> on_create) { mOnCreate = std::move(on_create); }
230
231 std::shared_ptr<INTERFACE> Resolve(Container& container, Lifetime lifetime)
232 {
233 if (CustomLifetime != Lifetime::Default)
234 lifetime = CustomLifetime;
235
236 if (StrongInstancePointer)
237 return StrongInstancePointer;
238
239 if (lifetime == Lifetime::ThreadSingleton)
240 {
241 auto& ptr = ThreadInstances[std::this_thread::get_id()];
242 if (!ptr)
244 return ptr;
245 }
246 else if (lifetime == Lifetime::WeakSingleton)
247 {
248 if (auto result = WeakInstancePointer.lock(); result)
249 return result;
250 auto ptr = Create(container);
251 WeakInstancePointer = ptr;
252 return ptr;
253 }
254 else if (lifetime == Lifetime::InstanceSingleton)
255 return StrongInstancePointer = Create(container);
256 else
257 return Create(container);
258 }
259
260 private:
261
262 void ResetInstance()
263 {
264 StrongInstancePointer.reset();
265 WeakInstancePointer.reset();
266 ThreadInstances.clear();
267 }
268
269 std::shared_ptr<INTERFACE> Create(Container& container) const
270 {
271 auto obj = mFactory(container);
272
273 if (!container.mDebugStore.contains(obj.get()))
274 container.mDebugStore.emplace(obj.get(), std::pair<std::type_index, std::weak_ptr<void>>{ typeid(*obj), std::weak_ptr<void>{obj} });
275
276 if (mOnCreate)
277 {
278 container.ReportCreation(obj, [this](Container& container, std::shared_ptr<void> instance) {
279 mOnCreate(container, std::static_pointer_cast<INTERFACE>(std::move(instance)));
280 });
281 }
282 return obj;
283 }
284
285 };
286
287 template <typename T>
288 concept has_default_lifetime = requires { { T::DefaultLifetime } -> std::convertible_to<Lifetime>; };
289
290 template <typename INTERFACE>
291 struct Container::InterfaceContainer : Container::BaseInterfaceContainer
292 {
293 template <typename I = INTERFACE>
294 static constexpr Lifetime GetDeclaredLifetime()
295 {
296 if constexpr (has_default_lifetime<I>)
297 return I::DefaultLifetime;
298 else
299 return Lifetime::Default;
300 }
301
302 InterfaceContainer() : BaseInterfaceContainer(GetDeclaredLifetime<INTERFACE>()) {}
303
304 template <typename IMPLEMENTATION, typename... ARGS>
305 void RegisterImplementationType(ARGS&&... args)
306 {
307 if (mImplementations.contains(typeid(IMPLEMENTATION)))
308 throw "already registered";
309
310 auto& impl = mImplementations[typeid(IMPLEMENTATION)];
311 if constexpr (is_same_as_any_v<DefaultImplementationStruct, ARGS...>)
312 mImplementationsInDeclarationOrder.insert(mImplementationsInDeclarationOrder.begin(), &impl);
313 else
314 mImplementationsInDeclarationOrder.push_back(&impl);
315
316 impl.Set(GetDeclaredLifetime<IMPLEMENTATION>());
317 (impl.Set(std::forward<ARGS>(args)), ...);
318 }
319
320 template <typename IMPLEMENTATION>
321 ImplementationContainer<INTERFACE>* GetImplementationContainer()
322 {
323 auto it = mImplementations.find(typeid(IMPLEMENTATION));
324 if (it == mImplementations.end())
325 return nullptr;
326 return &it->second;
327 }
328
329 std::shared_ptr<INTERFACE> Resolve(Container& container, Lifetime lifetime)
330 {
331 if (DefaultLifetime != Lifetime::Default)
332 lifetime = DefaultLifetime;
333
334 if (mImplementationsInDeclarationOrder.empty())
335 return {};
336
337 return mImplementationsInDeclarationOrder.back()->Resolve(container, lifetime);
338 }
339
340 std::vector<std::shared_ptr<INTERFACE>> ResolveAll(Container& container, Lifetime lifetime)
341 {
342 std::vector<std::shared_ptr<INTERFACE>> result;
343
344 for (auto& [type, impl] : mImplementations)
345 result.push_back(impl.Resolve(container, lifetime));
346
347 return result;
348 }
349
350 private:
351
352 std::map<std::type_index, ImplementationContainer<INTERFACE>> mImplementations;
353 std::vector<ImplementationContainer<INTERFACE>*> mImplementationsInDeclarationOrder;
354
355 };
356
357 template <typename INTERFACE>
358 bool Container::HasAnyImplementationsOf() const
359 {
360 if (auto it = mContainers.find(typeid(INTERFACE)); it != mContainers.end())
361 return !it->second->mImplementations.empty();
362 return false;
363 }
364
365 template<typename INTERFACE, typename IMPLEMENTATION, typename ...ARGS>
366 void Container::RegisterType(ARGS&& ...args)
367 {
370 static_assert(std::is_base_of_v<INTERFACE, IMPLEMENTATION>, "Implementation class must inherit from interface class");
371 static_assert(!std::is_abstract_v<IMPLEMENTATION>, "Implementation cannot be abstract");
372 if constexpr (std::is_base_of_v<INTERFACE, IMPLEMENTATION> && !std::is_abstract_v<IMPLEMENTATION>)
373 {
374 constexpr auto instance_given = is_same_as_any_v<std::shared_ptr<IMPLEMENTATION>, ARGS...>;
375 constexpr auto interface_factory = is_same_as_any_v<std::function<std::shared_ptr<INTERFACE>(Container&)>, ARGS...>;
376 constexpr auto impl_factory = is_same_as_any_v<std::function<std::shared_ptr<IMPLEMENTATION>(Container&)>, ARGS...>;
377 static_assert(!(instance_given && (interface_factory || impl_factory)), "Cannot register type with both factory and instance");
378 if constexpr (instance_given || interface_factory || impl_factory)
379 GetInterfaceContainer<INTERFACE>().RegisterImplementationType<IMPLEMENTATION>(std::forward<ARGS>(args)...);
380 else
381 GetInterfaceContainer<INTERFACE>().RegisterImplementationType<IMPLEMENTATION>(detail::ConstructorDescriptorForClass<IMPLEMENTATION>::CreateFactory(), std::forward<ARGS>(args)...);
382 }
383 }
384
385 template<typename INTERFACE>
386 std::shared_ptr<INTERFACE> Container::Resolve()
387 {
388 return GetInterfaceContainer<INTERFACE>().Resolve(*this, DefaultLifetime);
389 }
390
391 template<typename INTERFACE>
392 std::shared_ptr<INTERFACE> Container::ResolveByName(std::string_view name)
393 {
394 auto& interface_container = GetInterfaceContainer<INTERFACE>();
395 for (auto& impl : interface_container.mImplementations)
396 {
397 if (impl.Name == name)
398 return impl.Resolve(*this, DefaultLifetime);
399 }
400 return {};
401 }
402
403 template<typename INTERFACE>
404 std::vector<std::shared_ptr<INTERFACE>> Container::ResolveAll()
405 {
406 return GetInterfaceContainer<INTERFACE>().ResolveAll(*this, DefaultLifetime);
407 }
408
409 template <typename TYPE>
410 std::shared_ptr<TYPE> Container::Create()
411 {
412 return std::shared_ptr<TYPE>{ detail::ConstructorDescriptorForClass<TYPE>::Create(*this) };
413 }
414
415 template <typename TYPE>
416 std::unique_ptr<TYPE> Container::CreateRaw()
417 {
418 return std::unique_ptr<TYPE>{ detail::ConstructorDescriptorForClass<TYPE>::Create(*this) };
419 }
420
421 template<typename INTERFACE>
422 Container::InterfaceContainer<INTERFACE>& Container::GetInterfaceContainer()
423 {
424 auto& container = mContainers[typeid(INTERFACE)];
425 if (!container)
426 container = std::make_unique<InterfaceContainer<INTERFACE>>();
427 return *static_cast<InterfaceContainer<INTERFACE>*>(container.get());
428 }
429
430 template<typename INTERFACE, typename IMPLEMENTATION>
431 Container::ImplementationContainer<INTERFACE>* Container::GetImplementationContainer()
432 {
433 return GetInterfaceContainer<INTERFACE>().GetImplementationContainer<IMPLEMENTATION>();
434 }
435
436 template<typename INSTANCE>
437 void Container::ReportCreation(std::shared_ptr<INSTANCE> const& obj, std::function<void(Container&, std::shared_ptr<void>)> func)
438 {
439 for (auto& [ptr, callback] : mCreationsToReport)
440 {
441 if (ptr.get() == obj.get())
442 return;
443 }
444 mCreationsToReport.emplace_back(static_pointer_cast<void>(obj), std::move(func));
445 }
446
447 template<typename INTERFACE, typename T>
448 std::shared_ptr<INTERFACE> Container::Instantiate(T& factory)
449 {
450 if (find(mResolutionStack.begin(), mResolutionStack.end(), typeid(INTERFACE)) != mResolutionStack.end())
451 throw "circular dependency";
452 mResolutionStack.push_back(typeid(INTERFACE));
453
454 auto result = factory(*this);
455
456 mResolutionStack.pop_back();
457 if (mResolutionStack.empty())
458 ReportAwaitingCreations();
459
460 return std::static_pointer_cast<INTERFACE>(std::move(result));
461 }
462
463}
constexpr auto bit_count
Equal to the number of bits in the type.
Definition bits.h:33
The below code is based on Sun's libm library code, which is licensed under the following license:
TODO: Split into ContainerBuilder and Container (or [Dependency]Registry and [Dependency]Container) O...
Definition di.h:51
std::shared_ptr< TYPE > Create()
Creates a new instance of TYPE, resolving its constructor arguments from this container.
Definition di_impl.h:410
std::shared_ptr< INTERFACE > Resolve()
Resolves an instance of INTERFACE using the container's default lifetime.
Definition di_impl.h:386