Skip to content

Commit

Permalink
fix: enabled auth add support watsonx backend
Browse files Browse the repository at this point in the history
Signed-off-by: Guangya Liu <[email protected]>
  • Loading branch information
gyliu513 committed Jul 12, 2024
1 parent 34b6de3 commit d8c8584
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 16 deletions.
6 changes: 6 additions & 0 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ var addCmd = &cobra.Command{
if strings.ToLower(backend) == "amazonbedrock" {
_ = cmd.MarkFlagRequired("providerRegion")
}
if strings.ToLower(backend) == "watsonxai" {
_ = cmd.MarkFlagRequired("projectId")
}
},
Run: func(cmd *cobra.Command, args []string) {

Expand Down Expand Up @@ -132,6 +135,7 @@ var addCmd = &cobra.Command{
TopK: topK,
MaxTokens: maxTokens,
OrganizationId: organizationId,
ProjectID: projectId,
}

if providerIndex == -1 {
Expand Down Expand Up @@ -179,4 +183,6 @@ func init() {
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
// add flag for openai organization
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
// add flag for IBM Watsonx Project ID
addCmd.Flags().StringVarP(&projectId, "projectId", "j", "", "IBM Watsonx Project ID (only for watsonxai backend)")
}
1 change: 1 addition & 0 deletions cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var (
topK int32
maxTokens int
organizationId string
projectId string
)

var configAI ai.AIConfiguration
Expand Down
8 changes: 7 additions & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type IAIConfig interface {
GetProviderId() string
GetCompartmentId() string
GetOrganizationId() string
GetProjectId() string
GetCustomHeaders() []http.Header
}

Expand Down Expand Up @@ -119,6 +120,7 @@ type AIProvider struct {
TopK int32 `mapstructure:"topk" yaml:"topk,omitempty"`
MaxTokens int `mapstructure:"maxtokens" yaml:"maxtokens,omitempty"`
OrganizationId string `mapstructure:"organizationid" yaml:"organizationid,omitempty"`
ProjectID string `mapstructure:"projectid" yaml:"projectid,omitempty"`
CustomHeaders []http.Header `mapstructure:"customHeaders"`
}

Expand Down Expand Up @@ -177,11 +179,15 @@ func (p *AIProvider) GetOrganizationId() string {
return p.OrganizationId
}

func (p *AIProvider) GetProjectId() string {
return p.ProjectID
}

func (p *AIProvider) GetCustomHeaders() []http.Header {
return p.CustomHeaders
}

var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down
4 changes: 4 additions & 0 deletions pkg/ai/openai_header_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func (m *mockConfig) GetOrganizationId() string {
return ""
}

func (m *mockConfig) GetProjectId() string {
return ""
}

func (m *mockConfig) GetProxyEndpoint() string {
return ""
}
Expand Down
28 changes: 13 additions & 15 deletions pkg/ai/watsonxai.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package ai

import (
"os"
"fmt"
"context"
"errors"
"fmt"

wx "github.com/IBM/watsonx-go/pkg/models"
)
Expand All @@ -14,20 +13,20 @@ const watsonxAIClientName = "watsonxai"
type WatsonxAIClient struct {
nopCloser

client *wx.Client
model string
temperature float32
topP float32
topK int32
maxNewTokens int
client *wx.Client
model string
temperature float32
topP float32
topK int32
maxNewTokens int
}

const (
modelMetallama = "ibm/granite-13b-chat-v2"
)

func (c *WatsonxAIClient) Configure(config IAIConfig) error {
if(config.GetModel() == "") {
if config.GetModel() == "" {
c.model = config.GetModel()
} else {
c.model = modelMetallama
Expand All @@ -37,20 +36,19 @@ func (c *WatsonxAIClient) Configure(config IAIConfig) error {
c.topK = config.GetTopK()
c.maxNewTokens = config.GetMaxTokens()

// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)

apiKey := config.GetPassword()
if apiKey == "" {
return errors.New("No watsonx API key provided")
}
if projectID == "" {

projectId := config.GetProjectId()
if projectId == "" {
return errors.New("No watsonx project ID provided")
}

client, err := wx.NewClient(
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
wx.WithWatsonxProjectID(projectId),
)
if err != nil {
return fmt.Errorf("Failed to create client for testing. Error: %v", err)
Expand Down

0 comments on commit d8c8584

Please sign in to comment.