1372 lines
40 KiB
Go
1372 lines
40 KiB
Go
// Copyright 2023-2024 Buf Technologies, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package protoyaml
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"math/big"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bufbuild/protovalidate-go"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/reflect/protoregistry"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
var (
|
|
// We have to initialize this from an init() function below
|
|
// instead of via initializer expression here to avoid the Go
|
|
// compiler complaining about a potential initialization cycle
|
|
// (the initializer expression refers to the function
|
|
// unmarshalAnyMsg, which indirectly refers back to this var).
|
|
wktUnmarshalers map[protoreflect.FullName]customUnmarshaler
|
|
)
|
|
|
|
// Validator is an interface for validating a Protobuf message produced from a given YAML node.
|
|
type Validator interface {
|
|
// Validate the given message.
|
|
Validate(message proto.Message) error
|
|
}
|
|
|
|
// UnmarshalOptions is a configurable YAML format parser for Protobuf messages.
|
|
type UnmarshalOptions struct {
|
|
// The path for the data being unmarshaled.
|
|
//
|
|
// If set, this will be used when producing error messages.
|
|
Path string
|
|
// Validator is a validator to run after unmarshaling a message.
|
|
Validator Validator
|
|
// Resolver is the Protobuf type resolver to use.
|
|
Resolver interface {
|
|
protoregistry.MessageTypeResolver
|
|
protoregistry.ExtensionTypeResolver
|
|
}
|
|
|
|
// If AllowPartial is set, input for messages that will result in missing
|
|
// required fields will not return an error.
|
|
AllowPartial bool
|
|
|
|
// DiscardUnknown specifies whether to discard unknown fields instead of
|
|
// returning an error.
|
|
DiscardUnknown bool
|
|
}
|
|
|
|
// Unmarshal a Protobuf message from the given YAML data.
|
|
func Unmarshal(data []byte, message proto.Message) error {
|
|
return (UnmarshalOptions{}).Unmarshal(data, message)
|
|
}
|
|
|
|
// Unmarshal a Protobuf message from the given YAML data.
|
|
func (o UnmarshalOptions) Unmarshal(data []byte, message proto.Message) error {
|
|
var yamlFile yaml.Node
|
|
if err := yaml.Unmarshal(data, &yamlFile); err != nil {
|
|
return err
|
|
}
|
|
if err := o.unmarshalNode(&yamlFile, message, data); err != nil {
|
|
return err
|
|
}
|
|
if !o.AllowPartial {
|
|
if err := proto.CheckInitialized(message); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ParseDuration parses a duration string into a durationpb.Duration.
//
// The accepted grammar is [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+, plus the bare
// (optionally signed) string "0". Valid time units are "ns", "us" (or "µs"),
// "ms", "s", "m", "h".
//
// This function supports the full range of durationpb.Duration values, including
// those outside the range of time.Duration. When err is nil the returned
// duration is never nil.
func ParseDuration(str string) (*durationpb.Duration, error) {
	// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
	neg := false

	// Consume [-+]?
	if str != "" {
		if c := str[0]; c == '-' || c == '+' {
			neg = c == '-'
			str = str[1:]
		}
	}
	// Special case: a bare "0" needs no unit. Return a real zero value rather
	// than a typed nil pointer, so callers can use the result without a nil
	// check.
	if str == "0" {
		return &durationpb.Duration{}, nil
	}
	if str == "" {
		return nil, errors.New("invalid duration")
	}
	totalNanos := &big.Int{}
	for str != "" {
		var err error
		str, err = parseDurationNext(str, totalNanos)
		if err != nil {
			return nil, err
		}
	}
	if neg {
		totalNanos.Neg(totalNanos)
	}
	result := &durationpb.Duration{}
	quo, rem := totalNanos.QuoRem(totalNanos, nanosPerSecond, &big.Int{})
	if !quo.IsInt64() {
		return nil, errors.New("invalid duration: out of range")
	}
	result.Seconds = quo.Int64()
	result.Nanos = int32(rem.Int64()) //nolint:gosec // not an overflow risk; value is less than 2^30
	return result, nil
}
|
|
|
|
// unmarshalNode decodes the YAML tree rooted at node into message and then
// runs the configured Validator, if any. data is the raw YAML source; it is
// split into lines so errors can quote the offending line.
//
// An empty node (Kind 0) is a no-op. A document node must wrap exactly one
// root node, otherwise an error is returned immediately. All other problems
// are accumulated and returned together as unmarshalErrors.
func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, data []byte) error {
	if node.Kind == 0 {
		return nil
	}
	unm := &unmarshaler{
		options:   o,
		validator: o.Validator,
		lines:     strings.Split(string(data), "\n"),
	}

	// Unwrap the document node.
	if node.Kind == yaml.DocumentNode {
		if len(node.Content) != 1 {
			return errors.New("expected exactly one node in document")
		}
		node = node.Content[0]
	}

	unm.unmarshalMessage(node, message, false)

	if unm.validator != nil {
		if err := unm.validator.Validate(message); err != nil {
			var verr *protovalidate.ValidationError
			if errors.As(err, &verr) {
				// Attribute each violation to the node nearest its field path.
				desc := message.ProtoReflect().Descriptor()
				for _, violation := range verr.Violations {
					fieldPath := protovalidate.FieldPathString(violation.Proto.GetField())
					target := unm.nodeClosestToPath(node, desc, fieldPath, violation.Proto.GetForKey())
					unm.addError(target, &violationError{
						Violation: violation.Proto,
					})
				}
			} else {
				unm.addError(node, err)
			}
		}
	}

	if len(unm.errors) == 0 {
		return nil
	}
	return unmarshalErrors(unm.errors)
}
|
|
|
|
// atTypeFieldName is the reserved mapping key that carries the type URL of an
// expanded google.protobuf.Any.
const atTypeFieldName = "@type"
|
|
|
|
type protoResolver interface {
|
|
protoregistry.MessageTypeResolver
|
|
protoregistry.ExtensionTypeResolver
|
|
}
|
|
|
|
type unmarshaler struct {
|
|
options UnmarshalOptions
|
|
errors []error
|
|
validator Validator
|
|
lines []string
|
|
}
|
|
|
|
func (u *unmarshaler) addError(node *yaml.Node, err error) {
|
|
u.errors = append(u.errors, &nodeError{
|
|
Path: u.options.Path,
|
|
Node: node,
|
|
cause: err,
|
|
line: u.lines[node.Line-1],
|
|
})
|
|
}
|
|
// addErrorf records a fmt.Errorf-formatted error (so %w wrapping works)
// against node.
func (u *unmarshaler) addErrorf(node *yaml.Node, format string, args ...interface{}) {
	err := fmt.Errorf(format, args...)
	u.addError(node, err)
}
|
|
|
|
func (u *unmarshaler) checkKind(node *yaml.Node, expected yaml.Kind) bool {
|
|
if node.Kind != expected {
|
|
u.addErrorf(node, "expected %v, got %v", getNodeKind(expected), getNodeKind(node.Kind))
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (u *unmarshaler) checkTag(node *yaml.Node, expected string) {
|
|
if node.Tag != "" && node.Tag != expected {
|
|
u.addErrorf(node, "expected tag %v, got %v", expected, node.Tag)
|
|
}
|
|
}
|
|
|
|
func (u *unmarshaler) findAnyTypeURL(node *yaml.Node) string {
|
|
typeURL := ""
|
|
for i := 1; i < len(node.Content); i += 2 {
|
|
keyNode := node.Content[i-1]
|
|
valueNode := node.Content[i]
|
|
if keyNode.Value == atTypeFieldName && u.checkKind(valueNode, yaml.ScalarNode) {
|
|
typeURL = valueNode.Value
|
|
break
|
|
}
|
|
}
|
|
return typeURL
|
|
}
|
|
|
|
func (u *unmarshaler) resolveAnyType(typeURL string) (protoreflect.MessageType, error) {
|
|
// Get the message type.
|
|
msgType, err := u.getResolver().FindMessageByURL(typeURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return msgType, nil
|
|
}
|
|
|
|
// findAnyType resolves the embedded message type of an expanded Any mapping
// from its "@type" key. It fails when that key is absent or empty.
func (u *unmarshaler) findAnyType(node *yaml.Node) (protoreflect.MessageType, error) {
	if typeURL := u.findAnyTypeURL(node); typeURL != "" {
		return u.resolveAnyType(typeURL)
	}
	return nil, errors.New("missing @type field")
}
|
|
|
|
// Unmarshal the field based on the field kind, ignoring IsList and IsMap,
|
|
// which are handled by the caller.
|
|
func (u *unmarshaler) unmarshalScalar(
|
|
node *yaml.Node,
|
|
field protoreflect.FieldDescriptor,
|
|
forKey bool,
|
|
) (protoreflect.Value, bool) {
|
|
switch field.Kind() {
|
|
case protoreflect.BoolKind:
|
|
return protoreflect.ValueOfBool(u.unmarshalBool(node, forKey)), true
|
|
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
|
|
//nolint:gosec // not overflow risk since unmarshalInteger does range check
|
|
return protoreflect.ValueOfInt32(int32(u.unmarshalInteger(node, 32))), true
|
|
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
|
|
return protoreflect.ValueOfInt64(u.unmarshalInteger(node, 64)), true
|
|
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
|
|
//nolint:gosec // not overflow risk since unmarshalUnsigned does range check
|
|
return protoreflect.ValueOfUint32(uint32(u.unmarshalUnsigned(node, 32))), true
|
|
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
|
|
return protoreflect.ValueOfUint64(u.unmarshalUnsigned(node, 64)), true
|
|
case protoreflect.FloatKind:
|
|
return protoreflect.ValueOfFloat32(float32(u.unmarshalFloat(node, 32))), true
|
|
case protoreflect.DoubleKind:
|
|
return protoreflect.ValueOfFloat64(u.unmarshalFloat(node, 64)), true
|
|
case protoreflect.StringKind:
|
|
u.checkKind(node, yaml.ScalarNode)
|
|
return protoreflect.ValueOfString(node.Value), true
|
|
case protoreflect.BytesKind:
|
|
return protoreflect.ValueOfBytes(u.unmarshalBytes(node)), true
|
|
case protoreflect.EnumKind:
|
|
return protoreflect.ValueOfEnum(u.unmarshalEnum(node, field)), true
|
|
default:
|
|
u.addErrorf(node, "unimplemented scalar type %v", field.Kind())
|
|
return protoreflect.Value{}, false
|
|
}
|
|
}
|
|
|
|
// Base64 decodes the given node value.
|
|
func (u *unmarshaler) unmarshalBytes(node *yaml.Node) []byte {
|
|
if !u.checkKind(node, yaml.ScalarNode) {
|
|
return nil
|
|
}
|
|
|
|
enc := base64.StdEncoding
|
|
if strings.ContainsAny(node.Value, "-_") {
|
|
enc = base64.URLEncoding
|
|
}
|
|
if len(node.Value)%4 != 0 {
|
|
enc = enc.WithPadding(base64.NoPadding)
|
|
}
|
|
|
|
// base64 decode the value.
|
|
data, err := enc.DecodeString(node.Value)
|
|
if err != nil {
|
|
u.addErrorf(node, "invalid base64: %v", err)
|
|
}
|
|
return data
|
|
}
|
|
|
|
// Unmarshal raw `true` or `false` values, only allowing for strings for keys.
|
|
func (u *unmarshaler) unmarshalBool(node *yaml.Node, forKey bool) bool {
|
|
if u.checkKind(node, yaml.ScalarNode) {
|
|
switch node.Value {
|
|
case "true":
|
|
if !forKey {
|
|
u.checkTag(node, "!!bool")
|
|
}
|
|
return true
|
|
case "false":
|
|
if !forKey {
|
|
u.checkTag(node, "!!bool")
|
|
}
|
|
return false
|
|
default:
|
|
u.addErrorf(node, "expected bool, got %#v", node.Value)
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Unmarshal the given node into an enum value.
|
|
//
|
|
// Accepts either the enum name or number.
|
|
func (u *unmarshaler) unmarshalEnum(node *yaml.Node, field protoreflect.FieldDescriptor) protoreflect.EnumNumber {
|
|
u.checkKind(node, yaml.ScalarNode)
|
|
// Get the enum descriptor.
|
|
enumDesc := field.Enum()
|
|
if enumDesc.FullName() == "google.protobuf.NullValue" {
|
|
return 0
|
|
}
|
|
// Get the enum value.
|
|
enumVal := enumDesc.Values().ByName(protoreflect.Name(node.Value))
|
|
if enumVal == nil {
|
|
lit, err := parseIntLiteral(node.Value)
|
|
if err != nil {
|
|
u.addErrorf(node, "unknown enum value %#v, expected one of %v", node.Value,
|
|
getEnumValueNames(enumDesc.Values()))
|
|
return 0
|
|
} else if err := lit.checkI32(field); err != nil {
|
|
u.addErrorf(node, "%w, expected one of %v", err,
|
|
getEnumValueNames(enumDesc.Values()))
|
|
return 0
|
|
}
|
|
//nolint:gosec // not overflow risk since list.checkI32 call above does range check
|
|
num := protoreflect.EnumNumber(lit.value)
|
|
if lit.negative {
|
|
num = -num
|
|
}
|
|
return num
|
|
}
|
|
return enumVal.Number()
|
|
}
|
|
|
|
// unmarshalFloat decodes a scalar node as a floating-point number parsed with
// the given bit size (32 or 64).
//
// In addition to everything strconv.ParseFloat accepts, the YAML 1.2 core
// schema spellings of the special values are recognized: ".inf", ".Inf",
// ".INF" (optionally signed with '+' or '-') and ".nan", ".NaN", ".NAN".
// Parse errors are recorded and the value strconv produced is returned.
func (u *unmarshaler) unmarshalFloat(node *yaml.Node, bits int) float64 {
	if !u.checkKind(node, yaml.ScalarNode) {
		return 0
	}

	// strconv.ParseFloat does not understand the leading '.' used by YAML.
	switch node.Value {
	case ".inf", ".Inf", ".INF", "+.inf", "+.Inf", "+.INF":
		return math.Inf(1)
	case "-.inf", "-.Inf", "-.INF":
		return math.Inf(-1)
	case ".nan", ".NaN", ".NAN":
		return math.NaN()
	}

	parsed, err := strconv.ParseFloat(node.Value, bits)
	if err != nil {
		u.addErrorf(node, "invalid float: %v", err)
	}
	return parsed
}
|
|
|
|
// Unmarshal the given node into an unsigned integer with the given bits.
|
|
func (u *unmarshaler) unmarshalUnsigned(node *yaml.Node, bits int) uint64 {
|
|
if !u.checkKind(node, yaml.ScalarNode) {
|
|
return 0
|
|
}
|
|
|
|
parsed, err := parseUintLiteral(node.Value)
|
|
if err != nil {
|
|
u.addErrorf(node, "invalid integer: %v", err)
|
|
}
|
|
if bits < 64 && parsed >= 1<<bits {
|
|
u.addErrorf(node, "integer is too large: > %v", 1<<bits-1)
|
|
}
|
|
return parsed
|
|
}
|
|
|
|
// Unmarshal the given node into a signed integer with the given bits.
|
|
func (u *unmarshaler) unmarshalInteger(node *yaml.Node, bits int) int64 {
|
|
if !u.checkKind(node, yaml.ScalarNode) {
|
|
return 0
|
|
}
|
|
|
|
lit, err := parseIntLiteral(node.Value)
|
|
if err != nil {
|
|
u.addErrorf(node, "invalid integer: %v", err)
|
|
}
|
|
if lit.negative {
|
|
if lit.value <= 1<<(bits-1) {
|
|
//nolint:gosec // we just checked on previous line so not overflow risk
|
|
return -int64(lit.value)
|
|
}
|
|
u.addErrorf(node, "integer is too small: < %v", -(1 << (bits - 1)))
|
|
} else if lit.value >= 1<<(bits-1) {
|
|
u.addErrorf(node, "integer is too large: > %v", 1<<(bits-1)-1)
|
|
}
|
|
//nolint:gosec // we just checked above so not overflow risk
|
|
return int64(lit.value)
|
|
}
|
|
|
|
func getFieldNames(fields protoreflect.FieldDescriptors) []protoreflect.Name {
|
|
names := make([]protoreflect.Name, 0, fields.Len())
|
|
for i := 0; i < fields.Len(); i++ {
|
|
names = append(names, fields.Get(i).Name())
|
|
if i > 5 {
|
|
names = append(names, "...")
|
|
break
|
|
}
|
|
}
|
|
return names
|
|
}
|
|
|
|
func getEnumValueNames(values protoreflect.EnumValueDescriptors) []protoreflect.Name {
|
|
names := make([]protoreflect.Name, 0, values.Len())
|
|
for i := 0; i < values.Len(); i++ {
|
|
names = append(names, values.Get(i).Name())
|
|
if i > 5 {
|
|
names = append(names, "...")
|
|
break
|
|
}
|
|
}
|
|
return names
|
|
}
|
|
|
|
func getNodeKind(kind yaml.Kind) string {
|
|
switch kind {
|
|
case yaml.DocumentNode:
|
|
return "document"
|
|
case yaml.SequenceNode:
|
|
return "sequence"
|
|
case yaml.MappingNode:
|
|
return "mapping"
|
|
case yaml.ScalarNode:
|
|
return "scalar"
|
|
case yaml.AliasNode:
|
|
return "alias"
|
|
}
|
|
return fmt.Sprintf("unknown(%d)", kind)
|
|
}
|
|
|
|
// parseUintLiteral parses an unsigned integer literal.
//
// Accepted forms are decimal ("42"), hexadecimal ("0x2A"), octal ("0o52") and
// binary ("0b101010"); the prefix letter may be upper or lower case.
//
// Conversion through JSON/YAML may have turned integers into floats, possibly
// in exponential notation. A decimal value that is not a valid integer is
// therefore retried as a float (e.g. "42.0" or "4.2e1") and accepted when it
// is a whole, non-negative number below 2^53. Every such float64 is an exact
// integer. Prefixed literals never take this fallback.
//
// On error the returned value is 0.
func parseUintLiteral(value string) (uint64, error) {
	base := 10
	digits := value
	if len(value) >= 2 && value[0] == '0' {
		switch value[1] {
		case 'x', 'X':
			base = 16
		case 'o', 'O':
			base = 8
		case 'b', 'B':
			base = 2
		}
		if base != 10 {
			digits = value[2:]
		}
	}

	parsed, err := strconv.ParseUint(digits, base, 64)
	if err == nil {
		return parsed, nil
	}
	// The float fallback only applies to decimal text. Otherwise a malformed
	// prefixed literal such as "0b12" would be misread as the decimal 12.
	if base != 10 {
		return 0, err
	}
	parsedFloat, floatErr := strconv.ParseFloat(digits, 64)
	if floatErr != nil || parsedFloat < 0 || math.IsInf(parsedFloat, 0) || math.IsNaN(parsedFloat) {
		return 0, err
	}
	// Check before converting: converting an out-of-range float64 to uint64 is
	// implementation-defined in Go.
	if parsedFloat >= 1<<53 || parsedFloat != math.Trunc(parsedFloat) {
		return 0, errors.New("precision loss")
	}
	return uint64(parsedFloat), nil
}
|
|
|
|
// intLit is a parsed integer literal stored as sign plus magnitude, so the
// magnitude of the most negative int64 (2^63) is representable.
type intLit struct {
	negative bool   // the literal had a leading '-'
	value    uint64 // absolute value of the literal
}
|
|
|
|
func (lit intLit) checkI32(field protoreflect.FieldDescriptor) error {
|
|
switch {
|
|
case lit.negative && lit.value > 1<<31: // Underflow.
|
|
return fmt.Errorf("expected int32 for %v, got int64", field.FullName())
|
|
case !lit.negative && lit.value >= 1<<31: // Overflow.
|
|
return fmt.Errorf("expected int32 for %v, got int64", field.FullName())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func parseIntLiteral(value string) (intLit, error) {
|
|
var lit intLit
|
|
if strings.HasPrefix(value, "-") {
|
|
lit.negative = true
|
|
value = value[1:]
|
|
}
|
|
var err error
|
|
lit.value, err = parseUintLiteral(value)
|
|
return lit, err
|
|
}
|
|
|
|
func (u *unmarshaler) getResolver() protoResolver {
|
|
if u.options.Resolver != nil {
|
|
return u.options.Resolver
|
|
}
|
|
return protoregistry.GlobalTypes
|
|
}
|
|
|
|
// findField resolves a YAML mapping key to a field of msgDesc, trying in order:
//   - "[full.extension.name]": an extension, which must extend msgDesc and use
//     a number within one of its extension ranges;
//   - the field's JSON name;
//   - the field's text name;
//   - a positive decimal field number.
//
// protoregistry.NotFound is returned when none of these match.
func (u *unmarshaler) findField(key string, msgDesc protoreflect.MessageDescriptor) (protoreflect.FieldDescriptor, error) {
	if strings.HasPrefix(key, "[") && strings.HasSuffix(key, "]") {
		extName := protoreflect.FullName(strings.TrimSuffix(strings.TrimPrefix(key, "["), "]"))
		extType, err := u.getResolver().FindExtensionByName(extName)
		if err != nil {
			return nil, err
		}
		ext := extType.TypeDescriptor()
		extendsMsg := msgDesc.ExtensionRanges().Has(ext.Number()) &&
			ext.ContainingMessage().FullName() == msgDesc.FullName()
		if !extendsMsg {
			return nil, fmt.Errorf("message %v cannot be extended by %v", msgDesc.FullName(), ext.FullName())
		}
		return ext, nil
	}

	fields := msgDesc.Fields()
	if f := fields.ByJSONName(key); f != nil {
		return f, nil
	}
	if f := fields.ByTextName(key); f != nil {
		return f, nil
	}
	// ParseInt with bitSize 32 fails for anything outside int32, so a
	// successful positive result is always a valid FieldNumber.
	if num, err := strconv.ParseInt(key, 10, 32); err == nil && num > 0 {
		if f := fields.ByNumber(protoreflect.FieldNumber(num)); f != nil {
			return f, nil
		}
	}
	return nil, protoregistry.NotFound
}
|
|
|
|
// Unmarshal a field, handling isList/isMap.
|
|
func (u *unmarshaler) unmarshalField(node *yaml.Node, field protoreflect.FieldDescriptor, message proto.Message) {
|
|
if oneofDesc := field.ContainingOneof(); oneofDesc != nil && !oneofDesc.IsSynthetic() {
|
|
// Check if another field in the oneof is already set.
|
|
if whichOne := message.ProtoReflect().WhichOneof(oneofDesc); whichOne != nil {
|
|
u.addErrorf(node, "field %v is already set for oneof %v", whichOne.Name(), oneofDesc.Name())
|
|
return
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case field.IsList():
|
|
u.unmarshalList(node, field, message.ProtoReflect().Mutable(field).List())
|
|
case field.IsMap():
|
|
u.unmarshalMap(node, field, message.ProtoReflect().Mutable(field).Map())
|
|
case field.Message() != nil:
|
|
u.unmarshalMessage(node, message.ProtoReflect().Mutable(field).Message().Interface(), false)
|
|
default:
|
|
if val, ok := u.unmarshalScalar(node, field, false); ok {
|
|
message.ProtoReflect().Set(field, val)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Unmarshal the list, with explicit handling for lists of messages.
|
|
func (u *unmarshaler) unmarshalList(node *yaml.Node, field protoreflect.FieldDescriptor, list protoreflect.List) {
|
|
if u.checkKind(node, yaml.SequenceNode) {
|
|
switch field.Kind() {
|
|
case protoreflect.MessageKind, protoreflect.GroupKind:
|
|
for _, itemNode := range node.Content {
|
|
msgVal := list.NewElement()
|
|
u.unmarshalMessage(itemNode, msgVal.Message().Interface(), false)
|
|
list.Append(msgVal)
|
|
}
|
|
default:
|
|
for _, itemNode := range node.Content {
|
|
val, ok := u.unmarshalScalar(itemNode, field, false)
|
|
if !ok {
|
|
continue
|
|
}
|
|
list.Append(val)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Unmarshal the map, with explicit handling for maps to messages.
|
|
func (u *unmarshaler) unmarshalMap(node *yaml.Node, field protoreflect.FieldDescriptor, mapVal protoreflect.Map) {
|
|
if !u.checkKind(node, yaml.MappingNode) {
|
|
return
|
|
}
|
|
mapKeyField := field.MapKey()
|
|
mapValueField := field.MapValue()
|
|
for i := 1; i < len(node.Content); i += 2 {
|
|
keyNode := node.Content[i-1]
|
|
valueNode := node.Content[i]
|
|
mapKey, ok := u.unmarshalScalar(keyNode, mapKeyField, true)
|
|
if !ok {
|
|
continue
|
|
}
|
|
switch mapValueField.Kind() {
|
|
case protoreflect.MessageKind, protoreflect.GroupKind:
|
|
mapValue := mapVal.NewValue()
|
|
u.unmarshalMessage(valueNode, mapValue.Message().Interface(), false)
|
|
mapVal.Set(mapKey.MapKey(), mapValue)
|
|
default:
|
|
val, ok := u.unmarshalScalar(valueNode, mapValueField, false)
|
|
if !ok {
|
|
continue
|
|
}
|
|
mapVal.Set(mapKey.MapKey(), val)
|
|
}
|
|
}
|
|
}
|
|
|
|
func isNull(node *yaml.Node) bool {
|
|
return node.Tag == "!!null"
|
|
}
|
|
|
|
// Resolve the node to be used with the custom unmarshaler. Returns nil if the
|
|
// there was an error.
|
|
func (u *unmarshaler) findNodeForCustom(node *yaml.Node, forAny bool) *yaml.Node {
|
|
if !forAny {
|
|
return node
|
|
}
|
|
if !u.checkKind(node, yaml.MappingNode) {
|
|
return nil
|
|
}
|
|
var valueNode *yaml.Node
|
|
for i := 1; i < len(node.Content); i += 2 {
|
|
keyNode := node.Content[i-1]
|
|
switch keyNode.Value {
|
|
case "value":
|
|
valueNode = node.Content[i]
|
|
case atTypeFieldName:
|
|
continue // Skip the @type field for Any messages
|
|
default:
|
|
u.addErrorf(keyNode, "unknown field %#v, expended one of %v", keyNode.Value, []string{"value", atTypeFieldName})
|
|
return nil
|
|
}
|
|
}
|
|
if valueNode == nil {
|
|
u.addErrorf(node, "missing \"value\" field")
|
|
}
|
|
return valueNode
|
|
}
|
|
|
|
// Unmarshal the given yaml node into the given proto.Message.
|
|
func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, forAny bool) {
|
|
// Check for a custom unmarshaler
|
|
if custom, ok := wktUnmarshalers[message.ProtoReflect().Descriptor().FullName()]; ok {
|
|
valueNode := u.findNodeForCustom(node, forAny)
|
|
if valueNode == nil {
|
|
return // Error already added.
|
|
} else if custom(u, valueNode, message) {
|
|
return // Custom unmarshaler handled the decoding.
|
|
}
|
|
}
|
|
if isNull(node) {
|
|
return // Null is always allowed for messages
|
|
}
|
|
if node.Kind != yaml.MappingNode {
|
|
u.addErrorf(node, "expected fields for %v, got %v",
|
|
message.ProtoReflect().Descriptor().FullName(), getNodeKind(node.Kind))
|
|
return
|
|
}
|
|
u.unmarshalMessageFields(node, message, forAny)
|
|
}
|
|
|
|
// unmarshalMessageFields decodes each key/value pair of a mapping node into
// the matching field of message.
//
// Keys are normally scalars. A single-element sequence key whose element is a
// scalar (e.g. the YAML key `[foo.bar]`) is treated as the extension key
// "[foo.bar]". Other key shapes are reported as errors. When forAny is set,
// the "@type" key is skipped. Unknown keys are errors unless DiscardUnknown
// is set.
func (u *unmarshaler) unmarshalMessageFields(node *yaml.Node, message proto.Message, forAny bool) {
	msgDesc := message.ProtoReflect().Descriptor()
	for i := 0; i < len(node.Content); i += 2 {
		keyNode := node.Content[i]

		var key string
		switch {
		case keyNode.Kind == yaml.ScalarNode:
			key = keyNode.Value
		case keyNode.Kind == yaml.SequenceNode &&
			len(keyNode.Content) == 1 && keyNode.Content[0].Kind == yaml.ScalarNode:
			key = "[" + keyNode.Content[0].Value + "]"
		default:
			// Non-scalar key (or multi-element sequence): checkKind always
			// fails here and records the error.
			u.checkKind(keyNode, yaml.ScalarNode)
			continue
		}

		if forAny && key == atTypeFieldName {
			continue // The type URL was consumed when resolving the Any.
		}

		field, err := u.findField(key, msgDesc)
		if errors.Is(err, protoregistry.NotFound) {
			if !u.options.DiscardUnknown {
				u.addErrorf(keyNode, "unknown field %#v, expected one of %v", key, getFieldNames(msgDesc.Fields()))
			}
			continue
		}
		if err != nil {
			u.addError(keyNode, err)
			continue
		}
		u.unmarshalField(node.Content[i+1], field, message)
	}
}
|
|
|
|
// customUnmarshaler decodes node into message for one specific message type.
// Returning false declines the node, so generic message decoding is used
// instead. Returning true means the node was handled, even if an error was
// recorded on u.
type customUnmarshaler func(u *unmarshaler, node *yaml.Node, message proto.Message) bool
|
|
|
|
func unmarshalAnyMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
if node.Kind != yaml.MappingNode || len(node.Content) == 0 {
|
|
return false
|
|
}
|
|
anyVal, ok := message.(*anypb.Any)
|
|
if !ok {
|
|
anyVal = &anypb.Any{}
|
|
}
|
|
|
|
// Get the message type.
|
|
msgType, err := unm.findAnyType(node)
|
|
if err != nil {
|
|
unm.addError(node, err)
|
|
return true
|
|
}
|
|
|
|
protoVal := msgType.New()
|
|
unm.unmarshalMessage(node, protoVal.Interface(), true)
|
|
if err = anyVal.MarshalFrom(protoVal.Interface()); err != nil {
|
|
unm.addErrorf(node, "failed to marshal %v: %v", msgType.Descriptor().FullName(), err)
|
|
}
|
|
|
|
if !ok {
|
|
return setFieldByName(message, "type_url", protoreflect.ValueOfString(anyVal.GetTypeUrl())) &&
|
|
setFieldByName(message, "value", protoreflect.ValueOfBytes(anyVal.GetValue()))
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// Inclusive bounds, in Unix seconds, enforced by parseTimestamp.
const (
	// minTimestampSeconds is 0001-01-01T00:00:00Z.
	minTimestampSeconds = -62135596800
	// maxTimestampSeconds is 9999-12-31T23:59:59Z.
	maxTimestampSeconds = 253402300799
)
|
|
|
|
// parseTimestamp parses txt with the time.RFC3339Nano layout into timestamp.
// The instant must lie between 0001-01-01T00:00:00Z and 9999-12-31T23:59:59Z
// inclusive, and at most nine fractional-second digits are allowed. timestamp
// is modified only on success.
func parseTimestamp(txt string, timestamp *timestamppb.Timestamp) error {
	parsed, err := time.Parse(time.RFC3339Nano, txt)
	if err != nil {
		return err
	}

	secs := parsed.Unix()
	switch {
	case secs < minTimestampSeconds:
		return errors.New("before 0001-01-01T00:00:00Z")
	case secs > maxTimestampSeconds:
		return errors.New("after 9999-12-31T23:59:59Z")
	}

	// Count fractional digits explicitly: the span from the '.' to the zone
	// designator may be no longer than ".999999999".
	dot := strings.LastIndexByte(txt, '.')
	zone := strings.LastIndexAny(txt, "Z-+")
	if dot >= 0 && zone >= dot && zone-dot > len(".999999999") {
		return errors.New("too many fractional second digits")
	}

	timestamp.Seconds = secs
	timestamp.Nanos = int32(parsed.Nanosecond()) //nolint:gosec // not an overflow risk; value is less than 2^30
	return nil
}
|
|
|
|
func setFieldByName(message proto.Message, name string, value protoreflect.Value) bool {
|
|
field := message.ProtoReflect().Descriptor().Fields().ByName(protoreflect.Name(name))
|
|
if field == nil {
|
|
return false
|
|
}
|
|
message.ProtoReflect().Set(field, value)
|
|
return true
|
|
}
|
|
|
|
func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) {
|
|
return false
|
|
}
|
|
duration, err := ParseDuration(node.Value)
|
|
if err != nil {
|
|
unm.addError(node, err)
|
|
return true
|
|
}
|
|
|
|
if value, ok := message.(*durationpb.Duration); ok {
|
|
value.Seconds = duration.GetSeconds()
|
|
value.Nanos = duration.GetNanos()
|
|
return true
|
|
}
|
|
|
|
// Set the fields dynamically.
|
|
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.GetSeconds())) &&
|
|
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos()))
|
|
}
|
|
|
|
func unmarshalTimestampMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) {
|
|
return false
|
|
}
|
|
timestamp, ok := message.(*timestamppb.Timestamp)
|
|
if !ok {
|
|
timestamp = ×tamppb.Timestamp{}
|
|
}
|
|
err := parseTimestamp(node.Value, timestamp)
|
|
if err != nil {
|
|
unm.addErrorf(node, "invalid timestamp: %v", err)
|
|
} else if !ok {
|
|
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(timestamp.GetSeconds())) &&
|
|
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(timestamp.GetNanos()))
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Forwards unmarshaling to the "value" field of the given wrapper message.
|
|
func unmarshalWrapperMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
valueField := message.ProtoReflect().Descriptor().Fields().ByName("value")
|
|
if node.Kind == yaml.MappingNode || valueField == nil {
|
|
return false
|
|
}
|
|
unm.unmarshalField(node, valueField, message)
|
|
return true
|
|
}
|
|
|
|
func dynSetValue(message proto.Message, value *structpb.Value) bool {
|
|
switch val := value.GetKind().(type) {
|
|
case *structpb.Value_NullValue:
|
|
return setFieldByName(message, "null_value", protoreflect.ValueOfEnum(protoreflect.EnumNumber(val.NullValue)))
|
|
case *structpb.Value_NumberValue:
|
|
return setFieldByName(message, "number_value", protoreflect.ValueOfFloat64(val.NumberValue))
|
|
case *structpb.Value_StringValue:
|
|
return setFieldByName(message, "string_value", protoreflect.ValueOfString(val.StringValue))
|
|
case *structpb.Value_BoolValue:
|
|
return setFieldByName(message, "bool_value", protoreflect.ValueOfBool(val.BoolValue))
|
|
case *structpb.Value_ListValue:
|
|
listFld := message.ProtoReflect().Descriptor().Fields().ByName("list_value")
|
|
if listFld == nil {
|
|
return false
|
|
}
|
|
listVal := message.ProtoReflect().Mutable(listFld).Message().Interface()
|
|
return dynSetListValue(listVal, val.ListValue)
|
|
case *structpb.Value_StructValue:
|
|
structFld := message.ProtoReflect().Descriptor().Fields().ByName("struct_value")
|
|
if structFld == nil {
|
|
return false
|
|
}
|
|
structVal := message.ProtoReflect().Mutable(structFld).Message().Interface()
|
|
return dynSetStruct(structVal, val.StructValue)
|
|
}
|
|
return false
|
|
}
|
|
|
|
func dynSetListValue(message proto.Message, list *structpb.ListValue) bool {
|
|
valuesFld := message.ProtoReflect().Descriptor().Fields().ByName("values")
|
|
if valuesFld == nil {
|
|
return false
|
|
}
|
|
values := message.ProtoReflect().Mutable(valuesFld).List()
|
|
for _, item := range list.GetValues() {
|
|
value := values.NewElement()
|
|
if !dynSetValue(value.Message().Interface(), item) {
|
|
return false
|
|
}
|
|
values.Append(value)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func dynSetStruct(message proto.Message, structVal *structpb.Struct) bool {
|
|
fieldsFld := message.ProtoReflect().Descriptor().Fields().ByName("fields")
|
|
if fieldsFld == nil {
|
|
return false
|
|
}
|
|
fields := message.ProtoReflect().Mutable(fieldsFld).Map()
|
|
for key, item := range structVal.GetFields() {
|
|
value := fields.NewValue()
|
|
if !dynSetValue(value.Message().Interface(), item) {
|
|
return false
|
|
}
|
|
fields.Set(protoreflect.ValueOfString(key).MapKey(), value)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func unmarshalValueMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
value, ok := message.(*structpb.Value)
|
|
if !ok {
|
|
value = &structpb.Value{}
|
|
}
|
|
unm.unmarshalValue(node, value)
|
|
if !ok {
|
|
return dynSetValue(message, value)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func unmarshalListValueMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
if node.Kind != yaml.SequenceNode {
|
|
return false
|
|
}
|
|
listValue, ok := message.(*structpb.ListValue)
|
|
if !ok {
|
|
listValue = &structpb.ListValue{}
|
|
}
|
|
unm.unmarshalListValue(node, listValue)
|
|
if !ok {
|
|
return dynSetListValue(message, listValue)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func unmarshalStructMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
|
|
if node.Kind != yaml.MappingNode {
|
|
return false
|
|
}
|
|
structVal, ok := message.(*structpb.Struct)
|
|
if !ok {
|
|
structVal = &structpb.Struct{}
|
|
}
|
|
unm.unmarshalStruct(node, structVal)
|
|
if !ok {
|
|
return dynSetStruct(message, structVal)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Unmarshal the given yaml node into a structpb.Value, using the given
|
|
// field descriptor to validate the type, if non-nil.
|
|
func (u *unmarshaler) unmarshalValue(
|
|
node *yaml.Node,
|
|
value *structpb.Value,
|
|
) {
|
|
// Unmarshal the value.
|
|
switch node.Kind {
|
|
case yaml.SequenceNode: // A list.
|
|
listValue := &structpb.ListValue{}
|
|
u.unmarshalListValue(node, listValue)
|
|
value.Kind = &structpb.Value_ListValue{ListValue: listValue}
|
|
case yaml.MappingNode: // A message or map.
|
|
structVal := &structpb.Struct{}
|
|
u.unmarshalStruct(node, structVal)
|
|
value.Kind = &structpb.Value_StructValue{StructValue: structVal}
|
|
case yaml.ScalarNode:
|
|
u.unmarshalScalarValue(node, value)
|
|
case 0:
|
|
value.Kind = &structpb.Value_NullValue{}
|
|
default:
|
|
u.addErrorf(node, "unimplemented value kind: %v", getNodeKind(node.Kind))
|
|
}
|
|
}
|
|
|
|
// Unmarshal the given yaml node into a structpb.ListValue, using the given field
|
|
// descriptor to validate each item, if non-nil.
|
|
func (u *unmarshaler) unmarshalListValue(
|
|
node *yaml.Node,
|
|
list *structpb.ListValue,
|
|
) {
|
|
for _, itemNode := range node.Content {
|
|
itemValue := &structpb.Value{}
|
|
u.unmarshalValue(itemNode, itemValue)
|
|
list.Values = append(list.GetValues(), itemValue)
|
|
}
|
|
}
|
|
|
|
// Unmarshal the given yaml node into a structpb.Struct
|
|
//
|
|
// Structs can represent either a message or a map.
|
|
// For messages, the message descriptor can be provided or inferred from the node.
|
|
// For maps, the field descriptor can be provided to validate the map keys/values.
|
|
func (u *unmarshaler) unmarshalStruct(
|
|
node *yaml.Node,
|
|
message *structpb.Struct,
|
|
) {
|
|
for i := 1; i < len(node.Content); i += 2 {
|
|
keyNode := node.Content[i-1]
|
|
// Validate the key.
|
|
if !u.checkKind(keyNode, yaml.ScalarNode) {
|
|
continue
|
|
}
|
|
|
|
// Unmarshal the value.
|
|
valueNode := node.Content[i]
|
|
value := &structpb.Value{}
|
|
u.unmarshalValue(valueNode, value)
|
|
if message.GetFields() == nil {
|
|
message.Fields = make(map[string]*structpb.Value)
|
|
}
|
|
message.Fields[keyNode.Value] = value
|
|
}
|
|
}
|
|
|
|
func (u *unmarshaler) unmarshalScalarValue(node *yaml.Node, value *structpb.Value) {
|
|
switch node.Tag {
|
|
case "!!null":
|
|
value.Kind = &structpb.Value_NullValue{}
|
|
case "!!bool":
|
|
u.unmarshalScalarBool(node, value)
|
|
default:
|
|
u.unmarshalScalarString(node, value)
|
|
}
|
|
}
|
|
|
|
// bool, string, or bytes.
|
|
func (u *unmarshaler) unmarshalScalarBool(node *yaml.Node, value *structpb.Value) {
|
|
switch node.Value {
|
|
case "true":
|
|
value.Kind = &structpb.Value_BoolValue{BoolValue: true}
|
|
case "false":
|
|
value.Kind = &structpb.Value_BoolValue{BoolValue: false}
|
|
default: // This is a string, not a bool.
|
|
value.Kind = &structpb.Value_StringValue{StringValue: node.Value}
|
|
}
|
|
}
|
|
|
|
// unmarshalScalarString handles any scalar not tagged as null or bool.
//
// Text that does not parse as a finite float64 (including text that parses to
// ±Inf or NaN) is kept as a string value. A finite number is handed to
// unmarshalScalarFloat, which decides between a number and a string based on
// whether the float64 would lose integer precision.
func (u *unmarshaler) unmarshalScalarString(node *yaml.Node, value *structpb.Value) {
	num, err := strconv.ParseFloat(node.Value, 64)
	isFinite := err == nil && !math.IsInf(num, 0) && !math.IsNaN(num)
	if isFinite {
		u.unmarshalScalarFloat(node, value, num)
		return
	}
	value.Kind = &structpb.Value_StringValue{StringValue: node.Value}
}
|
|
|
|
// unmarshalScalarFloat stores the finite number floatVal, parsed from
// node.Value, into value, unless that would lose precision.
//
// If node.Value is not an integer literal according to parseIntLiteral, the
// float64 is stored as a number. If it is an integer literal, uint64(|floatVal|)
// is compared with the literal's value. When they differ, the float64 cannot
// represent the integer exactly, so the original text is kept as a string
// value to avoid dropping digits.
func (u *unmarshaler) unmarshalScalarFloat(node *yaml.Node, value *structpb.Value, floatVal float64) {
	// Try to parse it as an integer, to see if the float representation is lossy.
	lit, litErr := parseIntLiteral(node.Value)
	if litErr != nil {
		// Not an integer literal: the float64 is the intended representation.
		value.Kind = &structpb.Value_NumberValue{NumberValue: floatVal}
		return
	}

	// Converting a float64 >= 2^64 to uint64 is implementation-defined in Go.
	// It wraps on some platforms and saturates on others, and saturation would
	// make a literal near MaxUint64 look exactly representable. Such a float
	// can never equal a uint64 literal, so check the range first.
	absVal := math.Abs(floatVal)
	if absVal < 1<<64 && uint64(absVal) == lit.value {
		// Safe to represent as a number.
		value.Kind = &structpb.Value_NumberValue{NumberValue: floatVal}
		return
	}

	// Lossy: keep the exact textual representation.
	value.Kind = &structpb.Value_StringValue{StringValue: node.Value}
}
|
|
|
|
// NodeClosestToPath returns the node closest to the given field path.
|
|
//
|
|
// If toKey is true, the key node is returned if the path points to a map entry.
|
|
//
|
|
// Example field paths:
|
|
// - 'foo' -> the field foo
|
|
// - 'foo[0]' -> the first element of the repeated field foo or the map entry with key '0'
|
|
// - 'foo.bar' -> the field bar in the message field foo
|
|
// - 'foo["bar"]' -> the map entry with key 'bar' in the map field foo
|
|
func (u *unmarshaler) nodeClosestToPath(root *yaml.Node, msgDesc protoreflect.MessageDescriptor, path string, toKey bool) *yaml.Node {
|
|
parsedPath, err := parseFieldPath(path)
|
|
if err != nil {
|
|
return root
|
|
}
|
|
return u.findNodeByPath(root, msgDesc, parsedPath, toKey)
|
|
}
|
|
|
|
func parseFieldPath(path string) ([]string, error) {
|
|
if len(path) == 0 {
|
|
return nil, nil
|
|
}
|
|
next, path := parseNextFieldName(path)
|
|
result := []string{next}
|
|
for len(path) > 0 {
|
|
switch path[0] {
|
|
case '[': // Parse array index or map key.
|
|
next, path = parseNextValue(path[1:])
|
|
case '.': // Parse field name.
|
|
next, path = parseNextFieldName(path[1:])
|
|
default:
|
|
return nil, errors.New("invalid path")
|
|
}
|
|
result = append(result, next)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// parseNextFieldName consumes a bare field name from the front of path.
//
// The name ends at the first '.' or '[', and that delimiter stays at the
// start of the returned remainder. If neither delimiter occurs, the whole
// path is the name and the remainder is empty.
func parseNextFieldName(path string) (string, string) {
	if end := strings.IndexAny(path, ".["); end >= 0 {
		return path[:end], path[end:]
	}
	return path, ""
}
|
|
|
|
// parseNextValue consumes one bracketed path element from the front of path.
// path starts just after the opening '['.
//
// Two forms are accepted:
//   - a Go-quoted string such as "key", which is unquoted with strconv.Unquote;
//   - a bare token such as 0, which runs up to the next ']'.
//
// It returns the element and the remainder of path after the closing ']'.
// An unterminated element consumes the rest of path. A quoted element that
// fails to unquote yields ("", "").
func parseNextValue(path string) (string, string) {
	if path == "" {
		return "", ""
	}
	if path[0] == '"' {
		for i := 1; i < len(path); i++ {
			switch path[i] {
			case '\\':
				i++ // Skip the escaped character.
			case '"':
				result, err := strconv.Unquote(path[:i+1])
				if err != nil {
					return "", ""
				}
				// Drop the closing ']' only if it is actually there.
				// Blindly slicing path[i+2:] panics when the quote is the
				// last character (e.g. `foo["bar"`). Any other trailing
				// character is left for the caller to reject.
				return result, strings.TrimPrefix(path[i+1:], "]")
			}
		}
		return path, ""
	}
	// Bare token: go until the trailing ']'.
	if end := strings.IndexByte(path, ']'); end >= 0 {
		return path[:end], path[end+1:]
	}
	return path, ""
}
|
|
|
|
// findNodeByPath walks root along the parsed path segments and returns the
// deepest node it can reach.
//
// Traversal tracks either the message descriptor of the current node or the
// map field it belongs to:
//   - Message mapping: the segment is resolved as a field with u.findField and
//     looked up with findNodeByField.
//   - Map-field mapping: the segment is matched literally against the entry
//     keys with findEntryByKey.
//   - Sequence: the segment is parsed with strconv.Atoi and used as an index.
//
// Whenever a segment cannot be resolved, the current node is returned. If
// toKey is true and the final segment selects a map entry, that entry's key
// node is returned instead of its value.
func (u *unmarshaler) findNodeByPath(root *yaml.Node, msgDesc protoreflect.MessageDescriptor, path []string, toKey bool) *yaml.Node {
	node := root
	msgType := msgDesc
	var mapField protoreflect.FieldDescriptor
	last := len(path) - 1
	for pos, segment := range path {
		if node.Kind == yaml.SequenceNode {
			index, err := strconv.Atoi(segment)
			if err != nil || index < 0 || index >= len(node.Content) {
				return node
			}
			node = node.Content[index]
			continue
		}
		if node.Kind != yaml.MappingNode {
			return node
		}
		switch {
		case msgType != nil:
			field, err := u.findField(segment, msgType)
			if err != nil {
				return node
			}
			child, found := findNodeByField(node, field)
			if !found {
				return child
			}
			node = child
			if field.IsMap() {
				mapField, msgType = field, nil
			} else {
				mapField, msgType = nil, field.Message()
			}
		case mapField != nil:
			keyNode, valueNode, found := findEntryByKey(node, segment)
			if !found {
				return valueNode
			}
			if toKey && pos == last {
				return keyNode
			}
			node = valueNode
			// Descend into the map's value message type, if it has one.
			msgType, mapField = mapField.MapValue().Message(), nil
		}
	}
	return node
}
|
|
|
|
// findNodeByField searches the key/value pairs of the mapping node cur for an
// entry whose key names field.
//
// A key matches if it equals the proto field name, the JSON name, or the
// field number written in decimal. On a match it returns the entry's value
// node and true; otherwise it returns cur itself and false.
func findNodeByField(cur *yaml.Node, field protoreflect.FieldDescriptor) (*yaml.Node, bool) {
	// All three spellings are loop-invariant, so compute them once.
	protoName := string(field.Name())
	jsonName := field.JSONName()
	fieldNum := strconv.Itoa(int(field.Number()))
	for i := 1; i < len(cur.Content); i += 2 {
		switch cur.Content[i-1].Value {
		case protoName, jsonName, fieldNum:
			return cur.Content[i], true
		}
	}
	return cur, false
}
|
|
|
|
func findEntryByKey(cur *yaml.Node, key string) (*yaml.Node, *yaml.Node, bool) {
|
|
for i := 1; i < len(cur.Content); i += 2 {
|
|
keyNode := cur.Content[i-1]
|
|
if keyNode.Value == key {
|
|
return keyNode, cur.Content[i], true
|
|
}
|
|
}
|
|
return nil, cur, false
|
|
}
|
|
|
|
// nanosPerSecond is the number of nanoseconds in a second.
|
|
var nanosPerSecond = new(big.Int).SetUint64(uint64(time.Second / time.Nanosecond))
|
|
|
|
// nanosMap is a map of time unit names to their duration in nanoseconds.
|
|
var nanosMap = map[string]*big.Int{
|
|
"ns": new(big.Int).SetUint64(1), // Identity for nanos.
|
|
"us": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)),
|
|
"µs": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)), // U+00B5 = micro symbol
|
|
"μs": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)), // U+03BC = Greek letter mu
|
|
"ms": new(big.Int).SetUint64(uint64(time.Millisecond / time.Nanosecond)),
|
|
"s": nanosPerSecond,
|
|
"m": new(big.Int).SetUint64(uint64(time.Minute / time.Nanosecond)),
|
|
"h": new(big.Int).SetUint64(uint64(time.Hour / time.Nanosecond)),
|
|
}
|
|
|
|
// unitsNames is the (normalized) list of time unit names.
|
|
var unitsNames = []string{"h", "m", "s", "ms", "us", "ns"}
|
|
|
|
// parseDurationNext parses one "<number><unit>" segment from the front of str,
// such as "1.5h" in "1.5h30m". It adds the segment's length in nanoseconds to
// totalNanos and returns the unparsed remainder of str.
//
// Format rules:
//   - The number is an optional integer part plus an optional '.'-prefixed
//     fractional part. At least one digit is required, so "1", ".5" and "1."
//     are accepted but "." is not.
//   - The number must start with a digit or '.'; a sign is not accepted here.
//   - The unit must be a key of nanosMap.
//   - A fractional value that is not a whole number of nanoseconds is rejected.
//
// On error the remainder is "". totalNanos may already include the integer
// part if the fractional part is the one that failed.
func parseDurationNext(str string, totalNanos *big.Int) (string, error) {
	// The next character must be [0-9.]. Checking for an empty string first
	// turns a would-be index-out-of-range panic into an ordinary error.
	if str == "" || !(str[0] == '.' || '0' <= str[0] && str[0] <= '9') {
		return "", errors.New("invalid duration")
	}
	var err error
	var whole, frac uint64
	var pre bool // Whether we have seen a digit before the dot.
	whole, str, pre, err = leadingInt(str)
	if err != nil {
		return "", err
	}
	var scale *big.Int
	var post bool // Whether we have seen a digit after the dot.
	if str != "" && str[0] == '.' {
		str = str[1:]
		frac, scale, str, post = leadingFrac(str)
	}
	if !pre && !post {
		return "", errors.New("invalid duration")
	}

	end := unitEnd(str)
	if end == 0 {
		return "", fmt.Errorf("invalid duration: missing unit, expected one of %v", unitsNames)
	}
	unitName := str[:end]
	str = str[end:]
	nanosPerUnit, ok := nanosMap[unitName]
	if !ok {
		return "", fmt.Errorf("invalid duration: unknown unit, expected one of %v", unitsNames)
	}

	// Convert to nanos and add to total:
	//   totalNanos += whole*nanosPerUnit + frac*nanosPerUnit/scale
	if whole > 0 {
		wholeNanos := new(big.Int).SetUint64(whole)
		wholeNanos.Mul(wholeNanos, nanosPerUnit)
		totalNanos.Add(totalNanos, wholeNanos)
	}
	if frac > 0 {
		fracNanos := new(big.Int).SetUint64(frac)
		fracNanos.Mul(fracNanos, nanosPerUnit)
		rem := new(big.Int)
		fracNanos.QuoRem(fracNanos, scale, rem)
		// Test the sign, not rem.Uint64(). scale grows with every leading zero
		// in the fraction, so rem can exceed 64 bits. Uint64 is undefined for
		// such values and in practice returns only the low 64 bits, which can
		// be zero for a non-zero remainder.
		if rem.Sign() != 0 {
			return "", errors.New("invalid duration: fractional nanos")
		}
		totalNanos.Add(totalNanos, fracNanos)
	}
	return str, nil
}
|
|
|
|
func unitEnd(str string) int {
|
|
var i int
|
|
for ; i < len(str); i++ {
|
|
c := str[i]
|
|
if c == '.' || c == '-' || '0' <= c && c <= '9' {
|
|
return i
|
|
}
|
|
}
|
|
return i
|
|
}
|
|
|
|
func leadingFrac(str string) (result uint64, scale *big.Int, rem string, post bool) {
|
|
var i int
|
|
scale = big.NewInt(1)
|
|
big10 := big.NewInt(10)
|
|
var overflow bool
|
|
for ; i < len(str); i++ {
|
|
chr := str[i]
|
|
if chr < '0' || chr > '9' {
|
|
break
|
|
}
|
|
if overflow {
|
|
continue
|
|
}
|
|
if result > (1<<63-1)/10 {
|
|
overflow = true
|
|
continue
|
|
}
|
|
temp := result*10 + uint64(chr-'0')
|
|
if temp > 1<<63 {
|
|
overflow = true
|
|
continue
|
|
}
|
|
result = temp
|
|
scale.Mul(scale, big10)
|
|
}
|
|
return result, scale, str[i:], i > 0
|
|
}
|
|
|
|
// leadingInt parses the run of ASCII digits at the start of str as a uint64.
//
// It returns the value, the unconsumed rest of str, and whether at least one
// digit was present. If the digits exceed math.MaxUint64, it returns
// (0, str, true, "integer overflow").
func leadingInt(str string) (result uint64, rem string, pre bool, err error) {
	var i int
	for ; i < len(str); i++ {
		c := str[i]
		if c < '0' || c > '9' {
			break
		}
		digit := uint64(c - '0')
		// Check before multiplying that result*10+digit still fits.
		// Comparing the wrapped value against the old one afterwards misses
		// many overflows, e.g. 3e18*10 wraps to ~1.16e19.
		if result > (math.MaxUint64-digit)/10 {
			return 0, str, i > 0, errors.New("integer overflow")
		}
		result = result*10 + digit
	}
	return result, str[i:], i > 0, nil
}
|
|
|
|
func init() { //nolint:gochecknoinits
|
|
wktUnmarshalers = map[protoreflect.FullName]customUnmarshaler{
|
|
"google.protobuf.Any": unmarshalAnyMsg,
|
|
"google.protobuf.Duration": unmarshalDurationMsg,
|
|
"google.protobuf.Timestamp": unmarshalTimestampMsg,
|
|
"google.protobuf.BoolValue": unmarshalWrapperMsg,
|
|
"google.protobuf.BytesValue": unmarshalWrapperMsg,
|
|
"google.protobuf.DoubleValue": unmarshalWrapperMsg,
|
|
"google.protobuf.FloatValue": unmarshalWrapperMsg,
|
|
"google.protobuf.Int32Value": unmarshalWrapperMsg,
|
|
"google.protobuf.Int64Value": unmarshalWrapperMsg,
|
|
"google.protobuf.UInt32Value": unmarshalWrapperMsg,
|
|
"google.protobuf.UInt64Value": unmarshalWrapperMsg,
|
|
"google.protobuf.StringValue": unmarshalWrapperMsg,
|
|
"google.protobuf.Value": unmarshalValueMsg,
|
|
"google.protobuf.ListValue": unmarshalListValueMsg,
|
|
"google.protobuf.Struct": unmarshalStructMsg,
|
|
}
|
|
}
|