diff --git a/pkg/ai/amazonbedrock.go b/pkg/ai/amazonbedrock.go
index f9cd80fa61..c7868c382b 100644
--- a/pkg/ai/amazonbedrock.go
+++ b/pkg/ai/amazonbedrock.go
@@ -20,12 +20,8 @@ type AmazonBedRockClient struct {
 	client      *bedrockruntime.BedrockRuntime
 	model       string
 	temperature float32
-}
-
-// InvokeModelResponseBody represents the response body structure from the model invocation.
-type InvokeModelResponseBody struct {
-	Completion  string `json:"completion"`
-	Stop_reason string `json:"stop_reason"`
+	topP        float32
+	maxTokens   int
 }
 
 // Amazon BedRock support region list US East (N. Virginia),US West (Oregon),Asia Pacific (Singapore),Asia Pacific (Tokyo),Europe (Frankfurt)
@@ -52,14 +48,22 @@ const (
 	ModelAnthropicClaudeV2        = "anthropic.claude-v2"
 	ModelAnthropicClaudeV1        = "anthropic.claude-v1"
 	ModelAnthropicClaudeInstantV1 = "anthropic.claude-instant-v1"
+	ModelA21J2UltraV1             = "ai21.j2-ultra-v1"
+	ModelA21J2JumboInstruct       = "ai21.j2-jumbo-instruct"
+	ModelAmazonTitanExpressV1     = "amazon.titan-text-express-v1"
 )
 
 var BEDROCK_MODELS = []string{
 	ModelAnthropicClaudeV2,
 	ModelAnthropicClaudeV1,
 	ModelAnthropicClaudeInstantV1,
+	ModelA21J2UltraV1,
+	ModelA21J2JumboInstruct,
+	ModelAmazonTitanExpressV1,
 }
 
+// The hard-coded TOPP constant (0.9) moved to config.
+
 // GetModelOrDefault check config model
 func GetModelOrDefault(model string) string {
@@ -109,6 +113,8 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 	a.client = bedrockruntime.New(sess)
 	a.model = GetModelOrDefault(config.GetModel())
 	a.temperature = config.GetTemperature()
+	a.topP = config.GetTopP()
+	a.maxTokens = config.GetMaxTokens()
 	return nil
 }
 
@@ -116,14 +122,37 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
 
 // GetCompletion sends a request to the model for generating completion based on the provided prompt.
 func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
-	// Prepare the input data for the model invocation
-	request := map[string]interface{}{
-		"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
-		"max_tokens_to_sample": 1024,
-		"temperature":          a.temperature,
-		"top_p":                0.9,
+	// Prepare the input data for the model invocation; both the request and the response body formats vary per model.
+	var request map[string]interface{}
+	switch a.model {
+	case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
+		request = map[string]interface{}{
+			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
+			"max_tokens_to_sample": a.maxTokens,
+			"temperature":          a.temperature,
+			"top_p":                a.topP,
+		}
+	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
+		request = map[string]interface{}{
+			"prompt":      prompt,
+			"maxTokens":   a.maxTokens,
+			"temperature": a.temperature,
+			"topP":        a.topP,
+		}
+	case ModelAmazonTitanExpressV1:
+		request = map[string]interface{}{
+			"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
+			"textGenerationConfig": map[string]interface{}{
+				"maxTokenCount": a.maxTokens,
+				"temperature":   a.temperature,
+				"topP":          a.topP,
+			},
+		}
+	default:
+		return "", fmt.Errorf("model %s not supported", a.model)
+	}
+
 	body, err := json.Marshal(request)
 	if err != nil {
 		return "", err
@@ -142,15 +171,56 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
 	if err != nil {
 		return "", err
 	}
-	// Parse the response body
-	output := &InvokeModelResponseBody{}
-	err = json.Unmarshal(resp.Body, output)
-	if err != nil {
-		return "", err
-	}
-	return output.Completion, nil
+
+	// The response body structure also varies per model.
+	switch a.model {
+	case ModelAnthropicClaudeV2, ModelAnthropicClaudeV1, ModelAnthropicClaudeInstantV1:
+		type InvokeModelResponseBody struct {
+			Completion  string `json:"completion"`
+			Stop_reason string `json:"stop_reason"`
+		}
+		output := &InvokeModelResponseBody{}
+		err = json.Unmarshal(resp.Body, output)
+		if err != nil {
+			return "", err
+		}
+		return output.Completion, nil
+	case ModelA21J2UltraV1, ModelA21J2JumboInstruct:
+		type Data struct {
+			Text string `json:"text"`
+		}
+		type Completion struct {
+			Data Data `json:"data"`
+		}
+		type InvokeModelResponseBody struct {
+			Completions []Completion `json:"completions"`
+		}
+		output := &InvokeModelResponseBody{}
+		err = json.Unmarshal(resp.Body, output)
+		if err != nil {
+			return "", err
+		}
+		return output.Completions[0].Data.Text, nil
+	case ModelAmazonTitanExpressV1:
+		type Result struct {
+			TokenCount       int    `json:"tokenCount"`
+			OutputText       string `json:"outputText"`
+			CompletionReason string `json:"completionReason"`
+		}
+		type InvokeModelResponseBody struct {
+			InputTextTokenCount int      `json:"inputTextTokenCount"`
+			Results             []Result `json:"results"`
+		}
+		output := &InvokeModelResponseBody{}
+		err = json.Unmarshal(resp.Body, output)
+		if err != nil {
+			return "", err
+		}
+		return output.Results[0].OutputText, nil
+	default:
+		return "", fmt.Errorf("model %s not supported", a.model)
+	}
 }
-
 // GetName returns the name of the AmazonBedRockClient.
 func (a *AmazonBedRockClient) GetName() string {
 	return amazonbedrockAIClientName
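Review note: the request switch above encodes three provider-specific payload formats (`max_tokens_to_sample`/`top_p` for Anthropic, `maxTokens`/`topP` for AI21, and a nested `textGenerationConfig` for Titan). A standalone sketch of the JSON bodies it produces, useful for checking the field names against the Bedrock provider docs; the prompt and sampling values here are hypothetical, not taken from the patch:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	prompt := "Why is my pod in CrashLoopBackOff?" // hypothetical prompt
	temperature, topP, maxTokens := float32(0.7), float32(0.9), 1024

	// One entry per model family, mirroring the cases in the patch.
	requests := map[string]map[string]interface{}{
		"anthropic.claude-v2": {
			"prompt":               fmt.Sprintf("\n\nHuman: %s \n\nAssistant:", prompt),
			"max_tokens_to_sample": maxTokens,
			"temperature":          temperature,
			"top_p":                topP,
		},
		"ai21.j2-ultra-v1": {
			"prompt":      prompt,
			"maxTokens":   maxTokens,
			"temperature": temperature,
			"topP":        topP,
		},
		"amazon.titan-text-express-v1": {
			"inputText": fmt.Sprintf("\n\nUser: %s", prompt),
			"textGenerationConfig": map[string]interface{}{
				"maxTokenCount": maxTokens,
				"temperature":   temperature,
				"topP":          topP,
			},
		},
	}

	for model, req := range requests {
		body, err := json.Marshal(req)
		if err != nil {
			panic(err)
		}
		fmt.Printf("%s => %s\n", model, body)
	}
}
```

Titan is the odd one out: its sampling parameters sit under `textGenerationConfig` rather than at the top level, which is why the patch builds a nested map for that case.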
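A second review note, on the parsing switch: `output.Completions[0].Data.Text` and `output.Results[0].OutputText` index the first element without checking the slice length, so a response with an empty `completions`/`results` array would panic instead of returning an error. A minimal guarded variant for the AI21 branch (the helper name and error message are hypothetical; the types are copied from the patch, and the Titan branch would need the same check):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Types mirror the AI21 response structs declared inline in the patch.
type a21Data struct {
	Text string `json:"text"`
}

type a21Completion struct {
	Data a21Data `json:"data"`
}

type a21ResponseBody struct {
	Completions []a21Completion `json:"completions"`
}

// parseA21Completion is a hypothetical helper: it unmarshals the response
// body and returns an error instead of panicking when completions is empty.
func parseA21Completion(body []byte) (string, error) {
	output := &a21ResponseBody{}
	if err := json.Unmarshal(body, output); err != nil {
		return "", err
	}
	if len(output.Completions) == 0 {
		return "", fmt.Errorf("ai21 response contained no completions")
	}
	return output.Completions[0].Data.Text, nil
}

func main() {
	// A well-formed response parses normally...
	text, err := parseA21Completion([]byte(`{"completions":[{"data":{"text":"hello"}}]}`))
	fmt.Println(text, err)
	// ...while an empty completions array yields an error, not a panic.
	_, err = parseA21Completion([]byte(`{"completions":[]}`))
	fmt.Println(err)
}
```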