diff --git a/viper.go b/viper.go index 7f54c265a..714b2f2cc 100644 --- a/viper.go +++ b/viper.go @@ -827,10 +827,12 @@ func (v *Viper) isPathShadowedInDeepMap(path []string, m map[string]any) string // "foo.bar.baz" in a lower-priority map func (v *Viper) isPathShadowedInFlatMap(path []string, mi any) string { // unify input map - var m map[string]any - switch mi.(type) { - case map[string]string, map[string]FlagValue: - m = cast.ToStringMap(mi) + var m map[string]interface{} + switch miv := mi.(type) { + case map[string]string: + m = castMapStringToMapInterface(miv) + case map[string]FlagValue: + m = castMapFlagToMapInterface(miv) default: return "" } diff --git a/viper_test.go b/viper_test.go index 2dd67a189..98c379d7c 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2575,6 +2575,51 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +func TestIsPathShadowedInFlatMap(t *testing.T) { + v := New() + + stringMap := map[string]string{ + "foo": "value", + } + + flagMap := map[string]FlagValue{ + "foo": pflagValue{}, + } + + path1 := []string{"foo", "bar"} + expected1 := "foo" + + // "foo.bar" should shadowed by "foo" + assert.Equal(t, expected1, v.isPathShadowedInFlatMap(path1, stringMap)) + assert.Equal(t, expected1, v.isPathShadowedInFlatMap(path1, flagMap)) + + path2 := []string{"bar", "foo"} + expected2 := "" + + // "bar.foo" should not shadowed by "foo" + assert.Equal(t, expected2, v.isPathShadowedInFlatMap(path2, stringMap)) + assert.Equal(t, expected2, v.isPathShadowedInFlatMap(path2, flagMap)) +} + +func TestFlagShadow(t *testing.T) { + v := New() + + v.SetDefault("foo.bar1.bar2", "default") + + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.String("foo.bar1", "shadowed", "") + flags.VisitAll(func(flag *pflag.Flag) { + flag.Changed = true + }) + + v.BindPFlags(flags) + + assert.Equal(t, "shadowed", v.GetString("foo.bar1")) + // the default "foo.bar1.bar2" value should shadowed by flag "foo.bar1" value + // and should return an empty string + assert.Equal(t, "", v.GetString("foo.bar1.bar2")) +} + func BenchmarkGetBool(b *testing.B) { key := "BenchmarkGetBool" v = New()