# Copyright 2021-2024 The Khronos Group
# SPDX-License-Identifier: Apache-2.0

import sys
import os
import json

import merge_anari

def write_anari_enums(filename, anari):
    """Write the enum header (selected for output name anari_enums.h).

    The fixed preamble defines ANARIDataType. In C++ it is a thin wrapper
    struct around an int, with implicit conversion to and from int and an
    operator<. In C it is a plain ``int`` typedef.
    ANARIDataType values are emitted through ANARI_DATA_TYPE_DEFINE(v), so
    each constant has the matching type in either language. Every other enum
    becomes a typedef of its ``baseType`` plus one ``#define`` per value.

    Args:
        filename: path of the header to write. Missing parent directories
            are created. A bare file name writes into the current directory.
        anari: parsed API description. Reads ``anari['enums']``, each entry
            having ``name``, ``baseType`` (unused for ANARIDataType) and
            ``values`` (a list of ``name``/``value`` dicts).
    """
    parent = os.path.dirname(filename)
    if parent:  # os.makedirs('') raises, so only create a real directory part
        os.makedirs(parent, exist_ok=True)

    with open(filename, mode='w') as f:

        f.write("// Copyright 2024 The Khronos Group\n")
        f.write("// SPDX-License-Identifier: Apache-2.0\n\n")
        f.write("// This file was generated by "+os.path.basename(__file__)+"\n")
        f.write("// Don't make changes to this directly\n\n")

        f.write("""
#pragma once
#ifdef __cplusplus
struct ANARIDataType
{
  ANARIDataType() = default;
  constexpr ANARIDataType(int v) noexcept : value(v) {}
  constexpr operator int() const noexcept { return value; }
 private:
  int value;
};
constexpr bool operator<(ANARIDataType v1, ANARIDataType v2)
{
  return static_cast<int>(v1) < static_cast<int>(v2);
}
#define ANARI_DATA_TYPE_DEFINE(v) ANARIDataType(v)
#else
typedef int ANARIDataType;
#define ANARI_DATA_TYPE_DEFINE(v) v
#endif
""")

        # generate enums
        for enum in anari['enums']:
            # ANARIDataType's type is already declared by the preamble above.
            is_data_type = enum['name'] == 'ANARIDataType'
            if not is_data_type:
                f.write('\ntypedef '+enum['baseType']+' '+enum['name']+';\n')
            for value in enum['values']:
                if is_data_type:
                    f.write('#define '+value['name']+' ANARI_DATA_TYPE_DEFINE('+str(value['value'])+')\n')
                else:
                    f.write('#define '+value['name']+' '+str(value['value'])+'\n')

def _format_arguments(arguments):
    """Render a JSON argument list as a C parameter list (without parentheses).

    Every argument, including the first, honours an optional ``default`` key.
    The default is wrapped in ANARI_DEFAULT_VAL(...), which the generated
    anari.h expands to ``= default`` in C++ and to nothing in C.
    A falsy ``arguments`` (empty list or None) yields an empty string.
    """
    parts = []
    for arg in arguments or []:
        text = arg['type']+' '+arg['name']
        if 'default' in arg:
            text += ' ANARI_DEFAULT_VAL('+str(arg['default'])+')'
        parts.append(text)
    return ', '.join(parts)


def write_anari_header(filename, anari):
    """Write the main API header (selected for output name anari.h).

    Emits the following, in order:

    * the fixed preamble (includes, NULL / ANARI_INVALID_HANDLE, and the
      ANARI_INTERFACE, ANARI_DEPRECATED and ANARI_DEFAULT_VAL macros);
    * the opaque handle typedefs;
    * inside ``extern "C"``: the structs, the function-pointer typedefs and
      the exported function prototypes.

    Args:
        filename: path of the header to write. Missing parent directories
            are created. A bare file name writes into the current directory.
        anari: parsed API description. Reads ``opaqueTypes``, ``structs``,
            ``functionTypedefs`` and ``functions``.
    """
    parent = os.path.dirname(filename)
    if parent:  # os.makedirs('') raises, so only create a real directory part
        os.makedirs(parent, exist_ok=True)

    with open(filename, mode='w') as f:
        f.write("// Copyright 2021-2024 The Khronos Group\n")
        f.write("// SPDX-License-Identifier: Apache-2.0\n\n")
        f.write("// This file was generated by "+os.path.basename(__file__)+"\n")
        f.write("// Don't make changes to this directly\n\n")

        f.write("""
#pragma once

#include <stdint.h>
#include <sys/types.h>

#ifndef NULL
#if __cplusplus >= 201103L
#define NULL nullptr
#else
#define NULL 0
#endif
#endif

#define ANARI_INVALID_HANDLE NULL

#include "frontend/anari_sdk_version.h"
#include "frontend/anari_enums.h"

#ifdef _WIN32
#ifdef ANARI_STATIC_DEFINE
#define ANARI_INTERFACE
#else
#ifdef anari_EXPORTS
#define ANARI_INTERFACE __declspec(dllexport)
#else
#define ANARI_INTERFACE __declspec(dllimport)
#endif
#endif
#elif defined __GNUC__
#define ANARI_INTERFACE __attribute__((__visibility__("default")))
#else
#define ANARI_INTERFACE
#endif

#ifdef __GNUC__
#define ANARI_DEPRECATED __attribute__((deprecated))
#elif defined(_MSC_VER)
#define ANARI_DEPRECATED __declspec(deprecated)
#else
#define ANARI_DEPRECATED
#endif

#ifdef __cplusplus
// C++ DOES support default initializers
#define ANARI_DEFAULT_VAL(a) = a
#else
/* C99 does NOT support default initializers, so we use this macro
   to define them away */
#define ANARI_DEFAULT_VAL(a)
#endif
""")

        # Opaque handle types. In C++ each handle is a pointer to its own empty
        # struct in anari::api. The struct publicly derives from the declared
        # 'parent', so a derived handle converts implicitly to its parent
        # handle. In C every handle is just void*.
        # [5:] drops the 5-character 'ANARI' prefix (assumes every opaque
        # type name starts with it).
        f.write('#ifdef __cplusplus\n')
        f.write('namespace anari {\n')
        f.write('namespace api {\n')

        for opaque in anari['opaqueTypes']:
            if 'parent' in opaque:
                f.write('struct '+opaque['name'][5:]+' : public '+opaque['parent'][5:]+' {};\n')
            else:
                f.write('struct '+opaque['name'][5:]+' {};\n')

        f.write('} // namespace api\n')
        f.write('} // namespace anari\n')

        for opaque in anari['opaqueTypes']:
            f.write('typedef anari::api::'+opaque['name'][5:]+' *'+opaque['name']+';\n')

        f.write('#else\n')

        for opaque in anari['opaqueTypes']:
            f.write('typedef void* '+opaque['name']+';\n')

        f.write('#endif\n')

        f.write("""
#ifdef __cplusplus
extern "C" {
#endif

""")

        # generate structs; a member with 'elements' becomes a fixed-size array
        for struct in anari['structs']:
            f.write('typedef struct {\n')
            for member in struct['members'] or []:
                if 'elements' in member:
                    f.write('  '+member['type']+' '+member['name']+'['+str(member['elements'])+'];\n')
                else:
                    f.write('  '+member['type']+' '+member['name']+';\n')
            f.write('} '+struct['name']+';\n')

        f.write('\n')
        # generate function pointer typedefs
        for fun in anari['functionTypedefs']:
            f.write('typedef '+fun['returnType']+' (*'+fun['name']+')('
                    +_format_arguments(fun['arguments'])+');\n')

        f.write('\n')
        # generate exported function prototypes
        for fun in anari['functions']:
            f.write('ANARI_INTERFACE '+fun['returnType']+' '+fun['name']+'('
                    +_format_arguments(fun['arguments'])+');\n')

        f.write("""
#ifdef __cplusplus
} // extern "C"
#endif
""")


# Fixed start of type_utility.h.
# The scalar conversion helpers are outside any #ifdef. The header then opens
# the C++-only section (namespace anari and the primary ANARITypeProperties
# template). That section is closed by the generator below.
# anari_to_srgb uses the standard sRGB encoding 1.055*L^(1/2.4) - 0.055 for
# the non-linear branch, which is the inverse of anari_from_srgb.
_TYPE_UTILITY_PRELUDE = """
#pragma once

#include <anari/anari.h>
#include <math.h>
#include <stdint.h>

inline float anari_unit_clamp(float x) {
    if(x < -1.0f) {
        return -1.0f;
    } else if(x > 1.0f) {
        return 1.0f;
    } else {
        return x;
    }
}

inline int64_t anari_fixed_clamp(float x, int64_t max) {
    if(x <= -1.0f) {
        return -max;
    } else if(x >= 1.0f) {
        return max;
    } else {
        return (int64_t)(x*max);
    }
}

inline uint64_t anari_ufixed_clamp(float x, uint64_t max) {
    if(x <= 0.0f) {
        return 0u;
    } else if(x >= 1.0f) {
        return max;
    } else {
        return (uint64_t)(x*max);
    }
}

inline float anari_from_srgb(uint8_t x0) {
    float x = x0/(float)UINT8_MAX;
    if(x<=0.04045f) {
        return x*0.0773993808f;
    } else {
        return powf((x+0.055f)*0.94786729857f, 2.4f);
    }
}

inline uint8_t anari_to_srgb(float x) {
    if(x >= 1.0f) {
        return UINT8_MAX;
    } else if(x<=0.0f) {
        return 0u;
    } else if(x<=0.0031308f) {
        return (uint8_t)(x*12.92f*UINT8_MAX);
    } else {
        return (uint8_t)((1.055f*powf(x, 1.0f/2.4f)-0.055f)*UINT8_MAX);
    }
}


#ifdef __cplusplus

#include <utility>

namespace anari {

template<int type>
struct ANARITypeProperties { };

"""

# Hand-written conversion bodies for the sRGB 8-bit formats, keyed by
# ANARIDataType name. Each entry is a (toFloat4 body, fromFloat4 body) pair.
# Colour channels go through the sRGB curve. Alpha (where present) is linear
# and normalised by UINT8_MAX, matching the generic UFIXED8 handling.
_SRGB_CONVERSIONS = {
    'ANARI_UFIXED8_RGBA_SRGB': ("""        dst[0] = anari_from_srgb(src[0]);
        dst[1] = anari_from_srgb(src[1]);
        dst[2] = anari_from_srgb(src[2]);
        dst[3] = anari_unit_clamp(src[3]/(float)UINT8_MAX);
""", """        dst[0] = anari_to_srgb(src[0]);
        dst[1] = anari_to_srgb(src[1]);
        dst[2] = anari_to_srgb(src[2]);
        dst[3] = (base_type)anari_ufixed_clamp(float(src[3]), UINT8_MAX);
"""),
    'ANARI_UFIXED8_RGB_SRGB': ("""        dst[0] = anari_from_srgb(src[0]);
        dst[1] = anari_from_srgb(src[1]);
        dst[2] = anari_from_srgb(src[2]);
        dst[3] = 1.0f;
""", """        dst[0] = anari_to_srgb(src[0]);
        dst[1] = anari_to_srgb(src[1]);
        dst[2] = anari_to_srgb(src[2]);
"""),
    # RA: component 1 is alpha, mapped to/from float4 slot 3.
    'ANARI_UFIXED8_RA_SRGB': ("""        dst[0] = anari_from_srgb(src[0]);
        dst[1] = 0;
        dst[2] = 0;
        dst[3] = anari_unit_clamp(src[1]/(float)UINT8_MAX);
""", """        dst[0] = anari_to_srgb(src[0]);
        dst[1] = (base_type)anari_ufixed_clamp(src[3], UINT8_MAX);
"""),
    'ANARI_UFIXED8_R_SRGB': ("""        dst[0] = anari_from_srgb(src[0]);
        dst[1] = 0.0f;
        dst[2] = 0.0f;
        dst[3] = 1.0f;
""", """        dst[0] = anari_to_srgb(src[0]);
"""),
}


def _write_type_properties(f, value):
    """Emit the ANARITypeProperties<NAME> specialisation for one ANARIDataType value.

    Values with numeric id >= 1000 get real toFloat4/fromFloat4 conversions.
    * The sRGB formats use _SRGB_CONVERSIONS.
    * FIXED/UFIXED types are scaled by <BASE>_MAX and clamped with the
      anari_*_clamp helpers.
    * Everything else is a plain cast.
    Missing float4 slots read as 0, except alpha (slot 3), which reads as 1.
    Other values get no-op conversions.
    """
    name = value['name']
    base = value['baseType']
    count = str(value['elements'])

    f.write('template<>\n')
    f.write('struct ANARITypeProperties<'+name+'> {\n')
    f.write('    using base_type = '+base+';\n')
    f.write('    static const int components = '+count+';\n')
    f.write('    static const bool normalized = '+('true' if value['normalized'] else 'false')+';\n')
    f.write('    using array_type = base_type['+count+'];\n')
    f.write('    static constexpr const char* enum_name = "'+name+'";\n')
    f.write('    static constexpr const char* type_name = "'+base+'";\n')
    f.write('    static constexpr const char* array_name = "'+base+'['+count+']";\n')
    # [6:] drops the 6-character 'ANARI_' prefix
    f.write('    static constexpr const char* var_name = "var'+name[6:].lower()+'";\n')

    if value['value'] >= 1000:  # conversions for numerical types
        # e.g. 'uint8_t' -> 'UINT8_MAX' (assumes baseType ends in '_t' for fixed types)
        limit = base.upper()[0:-2]+'_MAX'

        f.write('    static void toFloat4(float *dst, const base_type *src) {\n')
        if name in _SRGB_CONVERSIONS:
            f.write(_SRGB_CONVERSIONS[name][0])
        else:
            for i in range(4):
                slot = str(i)
                if i >= value['elements']:
                    f.write('        dst['+slot+'] = '+('1.0f' if i == 3 else '0.0f')+';\n')
                elif 'FIXED' in name:
                    f.write('        dst['+slot+'] = anari_unit_clamp((float)src['+slot+']/(float)'+limit+');\n')
                else:
                    f.write('        dst['+slot+'] = (float)src['+slot+'];\n')
        f.write('    }\n')

        f.write('    static void fromFloat4(base_type *dst, const float *src) {\n')
        if name in _SRGB_CONVERSIONS:
            f.write(_SRGB_CONVERSIONS[name][1])
        else:
            for i in range(min(4, value['elements'])):
                slot = str(i)
                # test UFIXED first: 'FIXED' is also a substring of 'UFIXED'
                if 'UFIXED' in name:
                    f.write('        dst['+slot+'] = (base_type)anari_ufixed_clamp(src['+slot+'], '+limit+');\n')
                elif 'FIXED' in name:
                    f.write('        dst['+slot+'] = (base_type)anari_fixed_clamp(src['+slot+'], '+limit+');\n')
                else:
                    f.write('        dst['+slot+'] = (base_type)src['+slot+'];\n')
        f.write('    }\n')
    else:  # no op for non numeric types
        f.write('    static void fromFloat4(base_type *dst, const float *src) { }\n')
        f.write('    static void toFloat4(float *dst, const base_type *src) { }\n')

    f.write('};\n')


def _write_type_switch(f, signature, values, result, default, label=None):
    """Emit ``signature { switch (type) { case ...: return ...; default: ... } }``.

    Args:
        f: open text file to write to.
        signature: everything before the opening brace, e.g.
            'inline int isNormalized(ANARIDataType type)'.
        values: ANARIDataType value dicts, one case each.
        result: callable giving the returned C expression for a value.
        default: C expression returned by the default branch.
        label: optional callable giving the case label. Defaults to the
            value's enum name.
    """
    f.write('\n'+signature+' {\n    switch (type) {\n')
    for value in values:
        case = label(value) if label else value['name']
        f.write('        case '+case+': return '+result(value)+';\n')
    f.write('        default: return '+default+';\n')
    f.write('    }\n')
    f.write('}\n')


def write_anari_type_query_helper(filename, anari):
    """Write the type utility header (selected for output name type_utility.h).

    Contents:
    * C-visible scalar helpers (clamps and sRGB encode/decode).
    * C++ only: one ANARITypeProperties specialisation per ANARIDataType
      value, plus anariTypeInvoke for runtime-to-template dispatch.
    * Switch-based query functions: sizeOf, componentsOf, toString,
      typenameOf, varnameOf and isNormalized. These are placed inside
      namespace anari when compiled as C++.

    Args:
        filename: path of the header to write. Missing parent directories
            are created. A bare file name writes into the current directory.
        anari: parsed API description. Reads the 'ANARIDataType' entry of
            anari['enums']. Each of its values needs 'name', 'value',
            'baseType', 'elements' and 'normalized'.

    Raises:
        ValueError: if anari['enums'] has no 'ANARIDataType' entry. This is
            checked before the output file is opened, so nothing is
            truncated.
    """
    data_type = next((e for e in anari['enums'] if e['name'] == 'ANARIDataType'), None)
    if data_type is None:
        raise ValueError("API description has no 'ANARIDataType' enum")
    values = data_type['values']

    parent = os.path.dirname(filename)
    if parent:  # os.makedirs('') raises, so only create a real directory part
        os.makedirs(parent, exist_ok=True)

    with open(filename, mode='w') as f:
        f.write("// Copyright 2021-2024 The Khronos Group\n")
        f.write("// SPDX-License-Identifier: Apache-2.0\n\n")
        f.write("// This file was generated by "+os.path.basename(__file__)+"\n")
        f.write("// Don't make changes to this directly\n\n")

        f.write(_TYPE_UTILITY_PRELUDE)

        for value in values:
            _write_type_properties(f, value)

        # Still inside the prelude's '#ifdef __cplusplus' (templates are C++ only).
        _write_type_switch(
            f,
            'template <typename R, template<int> class F, typename... Args>\n'
            'R anariTypeInvoke(ANARIDataType type, Args&&... args)',
            values,
            lambda v: 'F<'+str(v['value'])+'>()(std::forward<Args>(args)...)',
            'F<ANARI_UNKNOWN>()(std::forward<Args>(args)...)',
            label=lambda v: str(v['value']))
        f.write('#endif\n')

        # These helpers are outside the C++-only guard (in C++ they still sit
        # inside namespace anari, which is not closed yet).
        _write_type_switch(f, 'inline size_t sizeOf(ANARIDataType type)', values,
                           lambda v: 'sizeof('+v['baseType']+')*'+str(v['elements']), '4')
        _write_type_switch(f, 'inline size_t componentsOf(ANARIDataType type)', values,
                           lambda v: str(v['elements']), '1')
        _write_type_switch(f, 'inline const char* toString(ANARIDataType type)', values,
                           lambda v: '"'+v['name']+'"', '"ANARI_UNKNOWN"')
        _write_type_switch(f, 'inline const char* typenameOf(ANARIDataType type)', values,
                           lambda v: '"'+v['baseType']+'"', '"ANARI_UNKNOWN"')
        _write_type_switch(f, 'inline const char* varnameOf(ANARIDataType type)', values,
                           lambda v: '"var'+v['name'][6:].lower()+'"', '"ANARI_UNKNOWN"')
        _write_type_switch(f, 'inline int isNormalized(ANARIDataType type)', values,
                           lambda v: str(int(v['normalized'])), '0')

        # namespace anari was opened under '#ifdef __cplusplus'. Close it under
        # the same guard so a C compiler never sees an unmatched '}'.
        f.write('#ifdef __cplusplus\n} // namespace anari\n#endif\n')

def main(argv=None):
    """Generate one ANARI header from one or more JSON API descriptions.

    Usage: <script> OUTPUT SPEC.json [MORE.json ...]

    The first JSON file is the base API description. Each further file is
    merged into it with merge_anari.merge. The base name of OUTPUT selects
    the generator: anari.h, anari_enums.h or type_utility.h.

    Args:
        argv: full argument vector including the program name. Defaults to
            sys.argv.

    Raises:
        SystemExit: with an error message if fewer than two arguments are
            given or OUTPUT has an unrecognised name, so a build step fails
            visibly instead of silently producing no header.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        sys.exit('usage: '+os.path.basename(__file__)+' OUTPUT SPEC.json [MORE.json ...]')

    output = argv[1]
    trees = []
    for path in argv[2:]:
        # close each spec file promptly; JSON is UTF-8 regardless of locale
        with open(path, encoding='utf-8') as spec:
            trees.append(json.load(spec))
    api = trees[0]
    for extra in trees[1:]:
        merge_anari.merge(api, extra)

    filename = os.path.basename(output)
    if filename == 'anari.h':
        write_anari_header(output, api)
    elif filename == 'anari_enums.h':
        write_anari_enums(output, api)
    elif filename == 'type_utility.h':
        write_anari_type_query_helper(output, api)
    else:
        sys.exit('unknown output header: '+filename
                 +' (expected anari.h, anari_enums.h or type_utility.h)')


if __name__ == '__main__':
    main()
