ConstValueProvider.java

/*
 * Copyright 2016 Providence Authors
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package net.morimekta.providence.reflect.util;

import net.morimekta.providence.PEnumBuilder;
import net.morimekta.providence.PEnumValue;
import net.morimekta.providence.PMessage;
import net.morimekta.providence.PMessageBuilder;
import net.morimekta.providence.PType;
import net.morimekta.providence.descriptor.PDeclaredDescriptor;
import net.morimekta.providence.descriptor.PDescriptor;
import net.morimekta.providence.descriptor.PEnumDescriptor;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.providence.descriptor.PList;
import net.morimekta.providence.descriptor.PMap;
import net.morimekta.providence.descriptor.PMessageDescriptor;
import net.morimekta.providence.descriptor.PSet;
import net.morimekta.providence.descriptor.PValueProvider;
import net.morimekta.providence.reflect.parser.ThriftException;
import net.morimekta.providence.reflect.parser.ThriftLexer;
import net.morimekta.providence.reflect.parser.ThriftToken;
import net.morimekta.providence.reflect.parser.ThriftTokenType;
import net.morimekta.providence.types.TypeReference;
import net.morimekta.providence.types.TypeRegistry;
import net.morimekta.util.Binary;
import net.morimekta.util.Strings;
import net.morimekta.util.collect.UnmodifiableList;
import net.morimekta.util.collect.UnmodifiableMap;
import net.morimekta.util.collect.UnmodifiableSet;
import net.morimekta.util.json.JsonException;
import net.morimekta.util.json.JsonToken;
import net.morimekta.util.json.JsonTokenizer;
import net.morimekta.util.lexer.LexerException;
import net.morimekta.util.lexer.Tokenizer;
import net.morimekta.util.lexer.TokenizerRepeater;
import net.morimekta.util.lexer.UncheckedLexerException;

import javax.annotation.Nonnull;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static java.nio.charset.StandardCharsets.US_ASCII;
import static net.morimekta.providence.reflect.parser.ThriftTokenizer.kNull;
import static net.morimekta.providence.types.TypeReference.parseType;
import static net.morimekta.providence.util.MessageUtil.coerceStrict;

/**
 * A value provider for thrift constants.
 */
/**
 * A value provider for thrift constants.
 * <p>
 * The constant is kept as the list of thrift tokens that make up its value,
 * and is only parsed into a typed value the first time {@link #get()} is
 * called. Type, enum and constant references inside the value are resolved
 * through the {@link TypeRegistry}, relative to the declaring program.
 */
public class ConstValueProvider implements PValueProvider<Object> {
    private final TypeRegistry      registry;
    private final String            programName;
    private final TypeReference     constType;
    private final List<ThriftToken> constTokens;
    // Null until the value has been parsed. The parsed value itself may be
    // null, so it is wrapped to tell "parsed as null" apart from "not parsed".
    private AtomicReference<Object> parsedValue;

    /**
     * Create a lazily parsed constant value provider.
     *
     * @param registry    Registry used to look up types and other constants.
     * @param programName Name of the program declaring the constant, used to
     *                    resolve references that are not program-qualified.
     * @param constType   The declared type of the constant.
     * @param constTokens The tokens making up the constant value.
     */
    public ConstValueProvider(@Nonnull TypeRegistry registry,
                              @Nonnull String programName,
                              @Nonnull TypeReference constType,
                              @Nonnull List<ThriftToken> constTokens) {
        this.registry = registry;
        this.programName = programName;
        this.constType = constType;
        this.constTokens = constTokens;
        this.parsedValue = null;
    }

    /**
     * Get the constant value. The tokens are parsed on the first call and the
     * result is cached for later calls. Note that the lazy initialization is
     * not synchronized.
     *
     * @return The parsed constant value, may be null.
     * @throws UncheckedLexerException If the constant value could not be parsed.
     * @throws UncheckedIOException    On other I/O failures while parsing.
     */
    @Override
    public Object get() {
        if (parsedValue == null) {
            PDescriptor type = registry
                    .getTypeProvider(constType, Collections.emptyMap())
                    .descriptor();
            Tokenizer<ThriftTokenType, ThriftToken> tokenizer = new TokenizerRepeater<>(constTokens);
            ThriftLexer                             lexer     = new ThriftLexer(tokenizer);
            try {
                parsedValue = new AtomicReference<>(
                        parseTypedValue(lexer.expect("const value"), lexer, type, true));
            } catch (LexerException e) {
                throw new UncheckedLexerException(e);
            } catch (IOException e) {
                throw new UncheckedIOException(e.getMessage(), e);
            }
        }

        return parsedValue.get();
    }

    /**
     * Parse the body of a message constant. The message start symbol must
     * already have been consumed; parsing continues up to and including the
     * message end symbol. Each field is a string field name, a key-value
     * separator and a value, optionally followed by a line separator.
     *
     * @param lexer     The thrift lexer.
     * @param type      The message type.
     * @param <Message> Message generic type.
     * @return The parsed message.
     * @throws IOException If the message could not be parsed.
     */
    private <Message extends PMessage<Message>>
    Message parseMessage(ThriftLexer lexer,
                         PMessageDescriptor<Message> type) throws IOException {
        PMessageBuilder<Message> builder = type.builder();

        ThriftToken token = lexer.expect("message field or end");
        while (!token.isSymbol(ThriftToken.kMessageEnd)) {
            if (token.type() != ThriftTokenType.STRING) {
                throw lexer.failure(token, "Invalid field name token");
            }
            PField<?> field = type.findFieldByName(token.decodeString(true));
            if (field == null) {
                throw lexer.failure(
                        token, "No such field in %s: %s",
                        type.getQualifiedName(), token.decodeString(true));
            }
            lexer.expectSymbol("message key-value sep", ThriftToken.kKeyValueSep);

            builder.set(field.getId(),
                        parseTypedValue(lexer.expect("parsing field value"),
                                        lexer,
                                        field.getDescriptor(),
                                        false));

            token = lexer.expect("message sep, field or end");
            if (token.isSymbol(ThriftToken.kLineSep1) || token.isSymbol(ThriftToken.kLineSep2)) {
                token = lexer.expect("message field or end");
            }
        }

        return builder.build();
    }

    /**
     * Parse a constant value of the given type, starting at the given token.
     * Enum references ({@code Enum.VALUE}, {@code program.Enum.VALUE}) and
     * constant references ({@code kConst}, {@code program.kConst}) are tried
     * first, then the token is parsed as a literal of the expected type.
     *
     * @param token     The first token of the value.
     * @param lexer     The thrift lexer, positioned after {@code token}.
     * @param valueType The expected value type.
     * @param allowNull If a {@code null} value is accepted (only for the
     *                  top-level constant value).
     * @return The parsed value.
     * @throws IOException If the value could not be parsed.
     */
    private Object parseTypedValue(ThriftToken token,
                                   ThriftLexer lexer,
                                   PDescriptor valueType,
                                   boolean allowNull)
            throws IOException {
        // Enum.VALUE
        // program.Enum.VALUE
        if (token.isQualifiedIdentifier() ||
            token.isDoubleQualifiedIdentifier()) {
            // Possible enum value. First strip after last '.' to get enum type name.
            String typeName = token.toString().replaceAll("\\.[^.]+$", "");

            Optional<PDeclaredDescriptor<?>> desc;
            if (typeName.equals(valueType.getName()) &&
                valueType.getType() == PType.ENUM) {
                desc = Optional.of((PDeclaredDescriptor<?>) valueType);
            } else {
                TypeReference ref = parseType(this.programName, typeName);
                desc = registry.getDeclaredType(ref);
            }

            if (desc.isPresent()) {
                if (!(desc.get() instanceof PEnumDescriptor)) {
                    throw new IllegalArgumentException("Not an enum type " + desc.get().getQualifiedName());
                }

                String        valueName = token.toString().replaceAll("^.*\\.", "");
                PEnumValue<?> value     = ((PEnumDescriptor<?>) desc.get()).findByName(valueName);
                if (value != null) {
                    return coerceStrict(valueType, value).orElseThrow(() -> new IllegalArgumentException("Non-matching enum value"));
                } else if (allowNull) {
                    return null;
                } else {
                    throw new IllegalArgumentException("No such " + desc.get().getQualifiedName() + " value " + valueName);
                }
            }
        }

        // kConstName
        // program.kConstName
        if (token.isIdentifier() || token.isQualifiedIdentifier()) {
            if (allowNull && kNull.equals(token.toString())) {
                return null;
            }

            Optional<Object> optional = registry.getConstantValue(parseType(this.programName, token.toString()));
            if (optional.isPresent()) {
                return coerceStrict(valueType, optional.get()).orElse(null);
            }
        }

        // Direct value.
        switch (valueType.getType()) {
            case BOOL:
                if (token.isIdentifier()) {
                    // Only the boolean literals are accepted; any other
                    // identifier is an error rather than silently 'false'.
                    String literal = token.toString();
                    if ("true".equalsIgnoreCase(literal)) {
                        return true;
                    } else if ("false".equalsIgnoreCase(literal)) {
                        return false;
                    }
                } else if (token.isInteger()) {
                    return token.parseInteger() != 0L;
                }
                throw lexer.failure(token, "Not boolean value: %s", token.toString());
            case BYTE:
                if (token.isInteger()) {
                    return (byte) token.parseInteger();
                }
                return (byte) findEnumValue(token.toString(), token, lexer, "byte");
            case I16:
                if (token.isInteger()) {
                    return (short) token.parseInteger();
                }
                return (short) findEnumValue(token.toString(), token, lexer, "i16");
            case I32:
                if (token.isInteger()) {
                    return (int) token.parseInteger();
                }
                return findEnumValue(token.toString(), token, lexer, "i32");
            case I64:
                if (token.isInteger()) {
                    return token.parseInteger();
                }
                return (long) findEnumValue(token.toString(), token, lexer, "i64");
            case DOUBLE:
                if (token.type() == ThriftTokenType.NUMBER) {
                    return token.parseDouble();
                }
                throw lexer.failure(token, token + " is not a valid double value.");
            case STRING:
                if (token.type() == ThriftTokenType.STRING) {
                    return token.decodeString(true);
                } else if (allowNull && token.toString().equals(kNull)) {
                    return null;
                }
                throw lexer.failure(token, "Not a valid string value.");
            case BINARY:
                if (token.type() == ThriftTokenType.STRING) {
                    return parseBinary(token.substring(1, -1)
                                            .toString());
                } else if (allowNull && token.toString().equals(kNull)) {
                    return null;
                }
                throw lexer.failure(token, "Not a valid binary value.");
            case ENUM: {
                // Enum references (Enum.VALUE) are handled above, so only the
                // numeric enum id form is accepted here.
                PEnumBuilder<?> eb   = ((PEnumDescriptor<?>) valueType).builder();
                String          name = token.toString();
                if (Strings.isInteger(name)) {
                    Object ev = eb.setById(Integer.parseInt(name)).build();
                    if (ev == null) {
                        throw lexer.failure(token,
                                            "No such " + valueType.getQualifiedName() + " enum value \"" + name + "\"");
                    }
                    return ev;
                }
                throw lexer.failure(token, "Not valid enum reference '" + name + "'");
            }
            case MESSAGE: {
                if (token.isSymbol(ThriftToken.kMessageStart)) {
                    return parseMessage(lexer, (PMessageDescriptor<?>) valueType);
                } else if (allowNull && token.toString().equals(kNull)) {
                    // messages can be null values in constants.
                    return null;
                }
                throw lexer.failure(token, "Not a valid message start.");
            }
            case LIST: {
                PDescriptor                      itemType = ((PList<?>) valueType).itemDescriptor();
                UnmodifiableList.Builder<Object> list     = UnmodifiableList.builder();

                if (!token.isSymbol(ThriftToken.kListStart)) {
                    throw lexer.failure(token, "Expected list start, found " + token.toString());
                }
                parseListItems(lexer, itemType, item -> list.add(item));

                return list.build();
            }
            case SET: {
                PDescriptor                     itemType = ((PSet<?>) valueType).itemDescriptor();
                UnmodifiableSet.Builder<Object> set      = UnmodifiableSet.builder();

                // Sets use the same literal syntax as lists.
                if (!token.isSymbol(ThriftToken.kListStart)) {
                    throw lexer.failure(token, "Expected list start, found " + token.toString());
                }
                parseListItems(lexer, itemType, item -> set.add(item));

                return set.build();
            }
            case MAP: {
                PDescriptor itemType = ((PMap<?, ?>) valueType).itemDescriptor();
                PDescriptor keyType = ((PMap<?, ?>) valueType).keyDescriptor();

                UnmodifiableMap.Builder<Object, Object> map = UnmodifiableMap.builder();

                if (!token.isSymbol(ThriftToken.kMessageStart)) {
                    throw lexer.failure(token, "Expected map start, found " + token.toString());
                }

                token = lexer.expect("map key or end");
                while (!token.isSymbol(ThriftToken.kMessageEnd)) {
                    Object key;
                    if (token.type() == ThriftTokenType.STRING) {
                        // Quoted key: decode, then interpret according to the key type.
                        key = parsePrimitiveKey(token.decodeString(true), token, lexer, keyType);
                    } else if (token.isIdentifier() || token.isQualifiedIdentifier()) {
                        // Unquoted identifier: first try it as a constant reference.
                        key = registry.getConstantValue(parseType(programName, token.toString())).orElse(null);
                        if (key == null) {
                            if (keyType.getType().equals(PType.STRING) ||
                                keyType.getType().equals(PType.BINARY)) {
                                throw lexer.failure(token, "Expected string literal for string key");
                            }
                            key = parsePrimitiveKey(token.toString(), token, lexer, keyType);
                        }
                    } else {
                        if (keyType.getType().equals(PType.STRING) ||
                            keyType.getType().equals(PType.BINARY)) {
                            throw lexer.failure(token, "Expected string literal for string key");
                        }
                        key = parsePrimitiveKey(token.toString(), token, lexer, keyType);
                    }
                    lexer.expectSymbol("map KV separator", ThriftToken.kKeyValueSep);
                    map.put(key, parseTypedValue(lexer.expect("map value"), lexer, itemType, false));

                    token = lexer.expect("map key, sep or end");
                    if (token.isSymbol(ThriftToken.kLineSep1) || token.isSymbol(ThriftToken.kLineSep2)) {
                        token = lexer.expect("map key or end");
                    }
                }

                return map.build();
            }
            default:
                throw new IllegalArgumentException("Unhandled item type " + valueType.getQualifiedName());
        }
    }

    /**
     * Parse the items of a list or set constant. The list start symbol must
     * already have been consumed; parsing continues up to and including the
     * list end symbol. Items may be followed by either line separator symbol.
     *
     * @param lexer    The thrift lexer.
     * @param itemType The item type.
     * @param consumer Receives each parsed item, in order.
     * @throws IOException If an item could not be parsed.
     */
    private void parseListItems(ThriftLexer lexer,
                                PDescriptor itemType,
                                Consumer<Object> consumer) throws IOException {
        ThriftToken token = lexer.expect("list item or end");
        while (!token.isSymbol(ThriftToken.kListEnd)) {
            consumer.accept(parseTypedValue(token, lexer, itemType, false));
            token = lexer.expect("list item, sep or end");
            if (token.isSymbol(ThriftToken.kLineSep1) || token.isSymbol(ThriftToken.kLineSep2)) {
                token = lexer.expect("list item or end");
            }
        }
    }

    /**
     * Parse a map key from its textual form according to the key type.
     *
     * @param key       The key text (already decoded if it was a string literal).
     * @param token     The key token, used for error reporting.
     * @param tokenizer The thrift lexer, used for error reporting.
     * @param keyType   The map key type.
     * @return The parsed key, never null.
     * @throws IOException If the key is not valid for the key type.
     */
    private Object parsePrimitiveKey(String key, ThriftToken token, ThriftLexer tokenizer, PDescriptor keyType)
            throws IOException {
        switch (keyType.getType()) {
            case ENUM: {
                PEnumBuilder<?> eb = ((PEnumDescriptor<?>) keyType).builder();
                Object          value;
                if (Strings.isInteger(key)) {
                    value = eb.setById(Integer.parseInt(key))
                              .build();
                } else {
                    String valueName = key;
                    if (valueName.startsWith(keyType.getProgramName() + "." + keyType.getName() + ".")) {
                        // Check for qualified type prefixed identifier ( e.g. program.EnumName.VALUE ).
                        valueName = valueName.substring(keyType.getProgramName().length() + keyType.getName().length() + 2);
                    } else if (valueName.startsWith(keyType.getName() + ".")) {
                        // Check for type prefixed identifier ( e.g. EnumName.VALUE ).
                        valueName = valueName.substring(keyType.getName().length() + 1);
                    }
                    value = eb.setByName(valueName)
                              .build();
                }
                // An unknown enum value must not become a null map key.
                if (value == null) {
                    throw tokenizer.failure(token, "No such %s enum value \"%s\"",
                                            keyType.getQualifiedName(), key);
                }
                return value;
            }
            case BOOL:
                return Boolean.parseBoolean(key);
            case BYTE:
                if (Strings.isInteger(key)) {
                    return Byte.parseByte(key);
                } else {
                    return (byte) findEnumValue(key, token, tokenizer, "byte");
                }
            case I16:
                if (Strings.isInteger(key)) {
                    return Short.parseShort(key);
                } else {
                    return (short) findEnumValue(key, token, tokenizer, "i16");
                }
            case I32:
                if (Strings.isInteger(key)) {
                    return Integer.parseInt(key);
                } else {
                    return findEnumValue(key, token, tokenizer, "i32");
                }
            case I64:
                if (Strings.isInteger(key)) {
                    return Long.parseLong(key);
                } else {
                    return (long) findEnumValue(key, token, tokenizer, "i64");
                }
            case DOUBLE: {
                try {
                    ByteArrayInputStream bais    = new ByteArrayInputStream(key.getBytes(US_ASCII));
                    JsonTokenizer        tokener = new JsonTokenizer(bais);
                    JsonToken            jt      = tokener.expect("parsing double value");
                    return jt.doubleValue();
                } catch (NumberFormatException | IOException | JsonException e) {
                    throw new ThriftException(token, "Unable to parse double value").initCause(e);
                }
            }
            case STRING:
                return key;
            case BINARY:
                return parseBinary(key);
            default:
                throw new ThriftException("Illegal key type: " + keyType.getType());
        }
    }

    /**
     * Resolve an enum reference ({@code Enum.VALUE} or
     * {@code program.Enum.VALUE}) to the integer value of the enum value.
     *
     * @param identifier   The enum reference.
     * @param token        The token, used for error reporting.
     * @param tokenizer    The thrift lexer, used for error reporting.
     * @param expectedType Name of the expected type, used in error messages.
     * @return The integer value of the referenced enum value.
     * @throws IOException If the identifier does not reference a known enum value.
     */
    private int findEnumValue(String identifier, ThriftToken token, ThriftLexer tokenizer, String expectedType)
            throws IOException {
        String[] parts = identifier.split("\\.", Byte.MAX_VALUE);
        String typeName;
        String valueName;
        if (parts.length == 3) {
            typeName = parts[0] + "." + parts[1];
            valueName = parts[2];
        } else if (parts.length == 2) {
            typeName = parts[0];
            valueName = parts[1];
        } else {
            throw tokenizer.failure(token, identifier + " is not a valid " + expectedType + " value.");
        }

        PDeclaredDescriptor<?> descriptor;
        try {
            descriptor = registry.requireDeclaredType(parseType(programName, typeName));
        } catch (IllegalArgumentException e) {
            throw tokenizer.failure(token, "No type named " + typeName + ".");
        }

        if (!(descriptor instanceof PEnumDescriptor)) {
            throw tokenizer.failure(token, typeName + " is not an enum.");
        }
        PEnumValue<?> value = ((PEnumDescriptor<?>) descriptor).findByName(valueName);
        if (value == null) {
            // The type is an enum, it just has no value with this name.
            throw tokenizer.failure(token, "No such %s value %s", descriptor.getQualifiedName(), valueName);
        }
        return value.asInteger();
    }

    /**
     * Decode a base64 encoded string into binary.
     *
     * @param value The base64 string to decode.
     * @return The decoded binary.
     */
    private Binary parseBinary(String value) {
        return Binary.fromBase64(value);
    }

}