首页 > 其他分享 >Discriminated Unions

Discriminated Unions

时间:2023-06-17 11:33:50浏览次数:51  
标签:std return value Discriminated Unions GetDerived template VariantChoice

Discriminated Unions意为可识别的union, 一个union变量知道自己是什么类型,或者说怎么去解释这片存储空间,std::variant即为C++17引入的类型安全的union类型,本文主解析std::variant的两种实现方式;

内存

如果借助tuple来实现:

template<typename... Types>
class Variant {
    public:
        std::tuple<Types...> storage;
        unsigned char discriminator; // 用于标识当前存储类型的索引
};

通过std::get<discriminator>(storage)来获取枚举值,但tuple的内存的布局显然不满足“union为对同一片内存的不同解释”这一原则;

因此修改为:

template<typename ...Ts>
struct Typelist {};

template<typename T>
struct LargestType;

template<typename T>
struct LargestType<Typelist<T>> {
    using Type = T;
};

template<typename T1, typename T2, typename ...Tails>
struct LargestType<Typelist<T1, T2, Tails...>> {
    using Type = std::conditional_t<(sizeof(T1) > sizeof(T2)), typename LargestType<Typelist<T1, Tails...>>::Type, typename LargestType<Typelist<T2, Tails...>>::Type>;
};

template<typename... Types>
class VariantStorage {
    using LargestT = typename LargestType<Typelist<Types...>>::Type;  // 获取占用内存最大的类型
    alignas(Types...) unsigned char m_buffer[sizeof(LargestT)];
    unsigned char m_discriminator = 0;
public:
    unsigned char GetDiscriminator() const { return m_discriminator; }
    void SetDiscriminator(unsigned char d) { m_discriminator=d; }
    void* GetRawBuffer() { return m_buffer; }
    const void* GetRawBuffer() const { return m_buffer; }

    template<typename T>
    T* GetBufferAs()
    {
        return std::launder(reinterpret_cast<T*>(m_buffer));
    }
    template<typename T>
    T const* GetBufferAs() const
    {
        return std::launder(reinterpret_cast<T const*>(m_buffer));
    }
};

代码中Typelist的常用技巧后续再总结介绍,这里通过LargestType将Types中内存最大的进行筛选,其中alignas取对齐大小要求满足所有类型;

存储问题解决,然后就是如果判断当前variant属于哪种类型?其中根据判断枚举类型的方法不同,本文总结两种实现方式;

discriminator 索引判断

核心实现部分为:

template<typename Tl, typename T, size_t index>
struct FindIndexImpl {
    using Type = std::integral_constant<size_t, index>;
    static constexpr size_t value = Type::value;
};

template<typename Head, typename ...Ts, typename T, size_t index>
struct FindIndexImpl <Typelist<Head, Ts...>, T, index> {
    using Type = std::conditional_t<std::is_same_v<Head, T>, std::integral_constant<size_t, index>, FindIndexImpl<Typelist<Ts...>, T, index + 1>>;
    static constexpr size_t value = Type::value;
};

template<typename Tl, typename T>
struct FindIndex;
template<typename ...Ts, typename T>
struct FindIndex<Typelist<Ts...>, T> {
    static constexpr size_t value = FindIndexImpl<Typelist<Ts...>, T, 1>::value;
};

template <typename... Types>
class Variant;

template <typename T, typename... Types>
class VariantChoice {
    using Derived = Variant<Types...>;

    Derived& GetDerived() { return *static_cast<Derived*>(this); }

    const Derived& GetDerived() const {
        return *static_cast<const Derived*>(this);
    }

    protected:
    static constexpr unsigned Discriminator = FindIndex<Typelist<Types...>, T>::value;
public:
    VariantChoice() = default;

    VariantChoice(const T& value)
    {
        new (GetDerived().GetRawBuffer()) T(value);  // CRTP
        GetDerived().SetDiscriminator(Discriminator);
    }

    VariantChoice(T&& value)
    {
        new (GetDerived().GetRawBuffer()) T(std::move(value));
        GetDerived().SetDiscriminator(Discriminator);
    }

    bool Destroy()
    {
        if (GetDerived().GetDiscriminator() == Discriminator) {
            GetDerived().template GetBufferAs<T>()->~T();
            return true;
        }
        return false;
    }

    Derived& operator= (const T& value)
    {
        if (GetDerived().GetDiscriminator() == Discriminator) {
            *GetDerived().template GetBufferAs<T>() = value;
        } else {
            GetDerived().Destroy();
            new (GetDerived().GetRawBuffer()) T(value);
            GetDerived().SetDiscriminator(Discriminator);
        }
        return GetDerived();
    }
    // TODO 移动赋值
};

template <typename... Types>
class Variant : private VariantStorage<Types...>,
                private VariantChoice<Types, Types...>... {
  template <typename T, typename... OtherTypes>
    friend class VariantChoice;
public:
    using VariantChoice<Types, Types...>::VariantChoice...;
    using VariantChoice<Types, Types...>::operator=...;
};

以Variant<Test1, Test2, Test3>为例子,其继承关系为:

Discriminated Unions_cpp

VariantChoice中持有不同枚举类型在类型列表中的索引,即在VariantChoice的构造函数中SetDiscriminator;

则可以通过索引判断枚举类型:

template <typename T>
    bool Is() const
    {
        return this->GetDiscriminator() == VariantChoice<T, Types...>::Discriminator;
    }
Variant<Test1, Test2, Test3> v(Test1 {});
    std::cout << v.Is<Test1>() << std::endl; // 输出为1,其他类型输出为0

查询

当查询枚举值时有几种情况:

(1) 查询的类型不在类型列表中

则在判断Is时,VariantChoice<T, Types...>::Discriminator无法找到对应Discriminator,编译出错

(2) 查询的类型在类型列表中,但非当前枚举的存储类型

则判断Is返回为false,抛出异常

template <typename T>
    T& Get() &
    {
        if (this->GetDiscriminator() == 0) {
            throw EmptyVariant();
        }
        assert(Is<T>());
        return *this->template GetBufferAs<T>();
    }

销毁

variant销毁时需要调用当前类型的析构函数

~Variant()
    {   
        Destroy();
    }   

    void Destroy() {
        (VariantChoice<Types, Types...>::Destroy(), ...);
        this->SetDiscriminator(0);
    }

在赋值操作时如果类型与当前存储类型不同,也需要析构当前存储对象;

typeIndex 判断

借助std::type_index 相比上面的实现更为简易,核心实现部分为:

template<size_t ...value>
struct MaxDataSize;

template<size_t val>
struct MaxDataSize<val> : std::integral_constant<size_t, val> {};

template<size_t val1, size_t val2, size_t ...resValue>
struct MaxDataSize<val1, val2, resValue...> :
    std::integral_constant<size_t,  val1 >= val2 ?
                                    MaxDataSize<val1, resValue...>::value :
                                    MaxDataSize<val2, resValue...>::value> {};

template<typename ...Ts>
struct MaxAlignSize : std::integral_constant<int32_t, MaxDataSize<std::alignment_of<Ts>::value...>::value> {};

template<typename ...Ts>
class Variant {
public:
    enum {
        DataSize = MaxDataSize<sizeof(Ts)...>::value,
        AlignSize = MaxAlignSize<Ts...>::value
    };
    using TL = TypeList<std::decay_t<Ts>...>;
    using DataStorage = typename std::aligned_storage<DataSize, AlignSize>::type;

    Variant() : m_typeIndex(typeid(void)) {}

    template<typename T, typename = std::enable_if_t<tlContains(std::decay_t<T> {}, TL {})>>
    Variant(T&& val) : m_typeIndex(typeid(void))
    {
        using U = typename std::decay_t<T>;
        new(&m_data) U(std::forward<T>(val));
        m_typeIndex = std::type_index(typeid(U));
    }

    ~Variant()
    {
        Destory(m_typeIndex, &m_data);
    }
    
    template<typename U>
    bool IsType()
    {
        if (std::type_index(typeid(std::decay_t<U>)) == m_typeIndex) {
            return true;
        }
        return false;
    }
    
    template<typename T, typename = std::enable_if_t<tlContains(std::decay_t<T> {}, TL {})>>
    std::decay_t<T>& Get()
    {
        if (std::type_index(typeid(void)) == m_typeIndex) {
            throw EmptyVariant();
        }
        return *(reinterpret_cast<std::decay_t<T>*>(&m_data));
    }
};
// TOOD 添加std::launder修改

补充知识:存储类型size根据硬件平台的不同内存大小对齐会有差异,因此通过std::alignment_of可以获取对齐大小,td::aligned_storage<sizeof(T), std::alignment_of<T>::value>可以获取存储类型对之后的大小,配合placement new在指定内存中执行不同对象的构造;

参考资料

【1】https://en.cppreference.com/w/cpp/utility/variant

【2】C++ templates


标签:std,return,value,Discriminated,Unions,GetDerived,template,VariantChoice
From: https://blog.51cto.com/u_13137973/6504630

相关文章