// Package rest provides a small HTTP client wrapper that applies a
// per-request timeout, supports JSON and multipart/form-data bodies, and
// publishes request metrics for every call.
package rest

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"time"

	"github.com/spf13/viper"
	"houston/common/metrics"
	"houston/common/util"
	"houston/logger"
)

// HttpRestClient describes the HTTP operations offered by this package.
//
// Every method returns the raw *http.Response; on a nil error the caller owns
// the response and is responsible for closing response.Body.
type HttpRestClient interface {
	// Post issues a POST request using the client's default timeout.
	Post(url string, body bytes.Buffer, headers map[string]string, formData map[string]interface{}) (*http.Response, error)
	// PostWithTimeout issues a POST request bounded by timeout.
	PostWithTimeout(url string, body bytes.Buffer, headers map[string]string, timeout time.Duration, formData map[string]interface{}) (*http.Response, error)
	// PutWithTimeout issues a PUT request bounded by timeout.
	PutWithTimeout(url string, body bytes.Buffer, headers map[string]string, timeout time.Duration, formData map[string]interface{}) (*http.Response, error)
	// Get issues a GET request using the client's default timeout.
	Get(url string, urlParams map[string]string, headers map[string]string, encodedURL bool) (*http.Response, error)
	// GetWithTimeout issues a GET request bounded by timeout.
	GetWithTimeout(url string, urlParams map[string]string, headers map[string]string, timeout time.Duration, encodedURL bool) (*http.Response, error)
}

// Compile-time check that the implementation satisfies the interface.
var _ HttpRestClient = (*HttpRestClientImpl)(nil)

// HttpRestClientImpl is the default HttpRestClient implementation.
type HttpRestClientImpl struct {
	// DefaultTimeout is used by Post and Get, which take no explicit timeout.
	DefaultTimeout time.Duration
}

// NewHttpRestClient returns a client whose default timeout is read from the
// "DEFAULT_HTTP_REQUEST_TIMEOUT" viper setting.
func NewHttpRestClient() *HttpRestClientImpl {
	return &HttpRestClientImpl{
		DefaultTimeout: viper.GetDuration("DEFAULT_HTTP_REQUEST_TIMEOUT"),
	}
}

// Post sends a POST request with the client's default timeout.
// See PostWithTimeout for the meaning of the parameters.
func (client *HttpRestClientImpl) Post(url string, body bytes.Buffer, headers map[string]string, formData map[string]interface{}) (*http.Response, error) {
	return client.PostWithTimeout(url, body, headers, client.DefaultTimeout, formData)
}

// PostWithTimeout sends a POST request to url.
//
// When formData is nil, body is sent as-is with the JSON content type. When
// formData is non-nil, body is ignored and a multipart/form-data payload is
// built from formData instead (see constructFormFields). Entries in headers
// are applied last, so they may override the Content-Type.
//
// It returns an error if the form payload or the request cannot be built, or
// if the HTTP exchange itself fails.
func (client *HttpRestClientImpl) PostWithTimeout(url string, body bytes.Buffer, headers map[string]string, timeout time.Duration, formData map[string]interface{}) (*http.Response, error) {
	return client.sendWithBody(http.MethodPost, url, body, headers, timeout, formData)
}

// PutWithTimeout sends a PUT request to url. Parameters and error behaviour
// are identical to PostWithTimeout.
func (client *HttpRestClientImpl) PutWithTimeout(url string, body bytes.Buffer, headers map[string]string, timeout time.Duration, formData map[string]interface{}) (*http.Response, error) {
	return client.sendWithBody(http.MethodPut, url, body, headers, timeout, formData)
}

// Get sends a GET request with the client's default timeout.
// See GetWithTimeout for the meaning of the parameters.
func (client *HttpRestClientImpl) Get(url string, urlParams map[string]string, headers map[string]string, encodedURL bool) (*http.Response, error) {
	return client.GetWithTimeout(url, urlParams, headers, client.DefaultTimeout, encodedURL)
}

// GetWithTimeout sends a GET request to url with urlParams appended as the
// query string. When encodedURL is true the parameters are percent-encoded;
// otherwise keys and values are appended verbatim.
//
// It returns an error if the request cannot be built (for example because the
// resulting URL is malformed) or if the HTTP exchange fails.
func (client *HttpRestClientImpl) GetWithTimeout(url string, urlParams map[string]string, headers map[string]string, timeout time.Duration, encodedURL bool) (*http.Response, error) {
	target := appendQuery(url, urlFromParams(urlParams, encodedURL))

	request, err := newRequest(http.MethodGet, target, nil)
	if err != nil {
		return nil, err
	}
	applyHeaders(request, headers)

	return client.makeHttpRequest(&http.Client{Timeout: timeout}, request)
}

// sendWithBody is the shared implementation behind the POST and PUT methods.
func (client *HttpRestClientImpl) sendWithBody(method, rawURL string, body bytes.Buffer, headers map[string]string, timeout time.Duration, formData map[string]interface{}) (*http.Response, error) {
	contentType := util.ContentTypeJSON
	if formData != nil {
		var err error
		body, contentType, err = constructFormFields(formData)
		if err != nil {
			return nil, err
		}
	}

	request, err := newRequest(method, rawURL, &body)
	if err != nil {
		return nil, err
	}
	request.Header.Set("Content-Type", contentType)
	// Caller-supplied headers go last so they can override the default
	// Content-Type.
	applyHeaders(request, headers)

	return client.makeHttpRequest(&http.Client{Timeout: timeout}, request)
}

// newRequest builds an *http.Request and reports construction failures
// instead of leaving a nil request behind.
//
// The deadline is deliberately NOT attached as a context that is cancelled
// when the calling method returns: the response body is handed back to the
// caller unread, and cancelling the request context at that point would abort
// the caller's body read. http.Client.Timeout is used instead, since it
// covers connection time, redirects, and reading the response body.
func newRequest(method, rawURL string, body io.Reader) (*http.Request, error) {
	request, err := http.NewRequestWithContext(context.Background(), method, rawURL, body)
	if err != nil {
		return nil, fmt.Errorf("building %s request for %q: %w", method, rawURL, err)
	}
	return request, nil
}

// applyHeaders sets every entry of headers on request, replacing any value
// already present under the same key.
func applyHeaders(request *http.Request, headers map[string]string) {
	for key, value := range headers {
		request.Header.Set(key, value)
	}
}

// makeHttpRequest executes request with restClient and publishes a metric
// carrying the method, the response status code (0 when no response was
// received) and the elapsed time in milliseconds.
func (client *HttpRestClientImpl) makeHttpRequest(restClient *http.Client, request *http.Request) (*http.Response, error) {
	startTime := time.Now()
	response, err := restClient.Do(request)
	duration := float64(time.Since(startTime).Milliseconds())

	// NOTE(review): the success path labels the metric with the full URL
	// (including the query string) while the failure path uses only the path.
	// Kept as-is to avoid changing existing metric labels — confirm whether
	// this asymmetry is intended.
	if response != nil {
		metrics.PublishHttpServerRequestMetrics(request.URL.String(), request.Method, response.StatusCode, duration)
	} else {
		metrics.PublishHttpServerRequestMetrics(request.URL.Path, request.Method, 0, duration)
	}
	return response, err
}

// appendQuery joins a query string produced by urlFromParams (either empty or
// starting with "?") onto base. If base already carries a query component the
// new pairs are appended with "&" rather than introducing a second "?".
func appendQuery(base, query string) string {
	if query == "" {
		return base
	}
	if !strings.Contains(base, "?") {
		return base + query
	}
	pairs := strings.TrimPrefix(query, "?")
	if strings.HasSuffix(base, "?") || strings.HasSuffix(base, "&") {
		return base + pairs
	}
	return base + "&" + pairs
}

// urlFromParams renders urlParams as a query string that starts with "?".
// It returns "" when urlParams is empty.
//
// With encode set, keys and values are percent-encoded via url.Values.
// Without it they are concatenated verbatim, so the caller must ensure they
// are already URL-safe. In both modes pairs are emitted in ascending key
// order so the output is deterministic.
func urlFromParams(urlParams map[string]string, encode bool) string {
	if len(urlParams) == 0 {
		return ""
	}

	if encode {
		values := make(url.Values, len(urlParams))
		for key, value := range urlParams {
			values.Set(key, value)
		}
		return "?" + values.Encode()
	}

	keys := make([]string, 0, len(urlParams))
	for key := range urlParams {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var query strings.Builder
	query.WriteByte('?')
	for i, key := range keys {
		if i > 0 {
			query.WriteByte('&')
		}
		query.WriteString(key)
		query.WriteByte('=')
		query.WriteString(urlParams[key])
	}
	return query.String()
}

// constructFormFields builds a multipart/form-data payload from formData and
// returns the encoded body together with its Content-Type (which carries the
// multipart boundary).
//
// Every string-valued entry except "filename" is written as a plain form
// field, in ascending key order; non-string entries are skipped. If "file" is
// present it must be an io.Reader and "filename" must be a string; the file
// part is written last so that it sits at the end of the form.
//
// An error is returned when a field or the file part cannot be written, or
// when "file"/"filename" have the wrong type.
func constructFormFields(formData map[string]interface{}) (bytes.Buffer, string, error) {
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	keys := make([]string, 0, len(formData))
	for key := range formData {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	// Add all non-file fields to the form first.
	for _, key := range keys {
		if key == "filename" {
			continue
		}
		text, isString := formData[key].(string)
		if !isString {
			continue
		}
		if err := writer.WriteField(key, text); err != nil {
			logger.Error(fmt.Sprintf("error writing form field: %v", err))
			return bytes.Buffer{}, "", err
		}
	}

	// Add the file last to ensure that it is at the end of the form.
	if rawFile := formData["file"]; rawFile != nil {
		file, isReader := rawFile.(io.Reader)
		if !isReader {
			err := fmt.Errorf("form field %q must be an io.Reader, got %T", "file", rawFile)
			logger.Error(fmt.Sprintf("error creating form file: %v", err))
			return bytes.Buffer{}, "", err
		}
		fileName, isString := formData["filename"].(string)
		if !isString {
			err := fmt.Errorf("form field %q must be a string when %q is set, got %T", "filename", "file", formData["filename"])
			logger.Error(fmt.Sprintf("error creating form file: %v", err))
			return bytes.Buffer{}, "", err
		}

		part, err := writer.CreateFormFile("file", fileName)
		if err != nil {
			logger.Error(fmt.Sprintf("error creating form file: %v", err))
			return bytes.Buffer{}, "", err
		}
		if _, err = io.Copy(part, file); err != nil {
			logger.Error(fmt.Sprintf("error copying file data: %v", err))
			return bytes.Buffer{}, "", err
		}
	}

	// Close writes the trailing boundary; the payload is incomplete without it.
	if err := writer.Close(); err != nil {
		logger.Error(fmt.Sprintf("error closing multipart writer: %v", err))
		return bytes.Buffer{}, "", err
	}
	return *body, writer.FormDataContentType(), nil
}