参考文心一言ERNIE-Bot 4.0模型流式和非流式API调用(SpringBoot+OkHttp3+SSE+WebSocket) - autunomy - 博客园 (cnblogs.com)
后端
引入依赖
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.9.3</version>
</dependency>
<!-- fastjson releases up to 1.2.80 are covered by a published autoType-bypass security advisory; 1.2.83 is the patched 1.x release. -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.83</version>
</dependency>
配置okHttp,创建OkHttpConfiguration.java
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import javax.net.ssl.*;
import java.security.*;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;
/**
 * Spring configuration that exposes a shared {@link OkHttpClient} bean together with the
 * TLS and connection-pool beans it is built from. All tunables are read from the
 * {@code ok.http.*} properties; every duration is interpreted in seconds.
 *
 * <p><strong>Security warning:</strong> the trust manager below accepts every certificate
 * chain and the hostname verifier accepts every host name, so connections made through this
 * client are encrypted but not authenticated. That is only acceptable for local
 * experiments; replace both with the platform defaults before production use.
 */
@Configuration
public class OkHttpConfiguration {

    /** Connect timeout in seconds. */
    @Value("${ok.http.connect-timeout}")
    private Integer connectTimeout;

    /** Read timeout in seconds. */
    @Value("${ok.http.read-timeout}")
    private Integer readTimeout;

    /** Write timeout in seconds. */
    @Value("${ok.http.write-timeout}")
    private Integer writeTimeout;

    /** Maximum number of idle connections kept in the pool. */
    @Value("${ok.http.max-idle-connections}")
    private Integer maxIdleConnections;

    /** How long an idle connection is kept alive, in seconds. */
    @Value("${ok.http.keep-alive-duration}")
    private Long keepAliveDuration;

    /**
     * Builds the application-wide HTTP client.
     *
     * @return a client using the configured timeouts, pool and (trust-all) TLS settings
     */
    @Bean
    public OkHttpClient okHttpClient() {
        return new OkHttpClient.Builder()
                .sslSocketFactory(sslSocketFactory(), x509TrustManager())
                // Do not transparently retry when establishing a connection fails.
                .retryOnConnectionFailure(false)
                .connectionPool(pool())
                .connectTimeout(connectTimeout, TimeUnit.SECONDS)
                .readTimeout(readTimeout, TimeUnit.SECONDS)
                .writeTimeout(writeTimeout, TimeUnit.SECONDS)
                // WARNING: accepts any host name; see the class-level security warning.
                .hostnameVerifier((hostname, session) -> true)
                // Proxy (disabled):
                // .proxy(new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 8888)))
                // Interceptors (none registered):
                // .addInterceptor()
                .build();
    }

    /**
     * Creates a trust manager that performs no certificate validation at all.
     *
     * @return a trust manager whose check methods never reject a chain
     */
    @Bean
    public X509TrustManager x509TrustManager() {
        return new X509TrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
                // Intentionally empty: every client certificate is accepted.
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
                // Intentionally empty: every server certificate is accepted.
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                return new X509Certificate[0];
            }
        };
    }

    /**
     * Creates the socket factory backed by the trust-all trust manager.
     *
     * @return a TLS socket factory
     * @throws IllegalStateException if the TLS context cannot be created or initialised;
     *         failing here keeps the root cause instead of handing a {@code null} factory
     *         to the client builder
     */
    @Bean
    public SSLSocketFactory sslSocketFactory() {
        try {
            SSLContext sslContext = SSLContext.getInstance("TLS");
            sslContext.init(null, new TrustManager[] {x509TrustManager()}, new SecureRandom());
            return sslContext.getSocketFactory();
        } catch (NoSuchAlgorithmException | KeyManagementException e) {
            throw new IllegalStateException("Unable to initialise the TLS context for OkHttp", e);
        }
    }

    /**
     * Creates the connection pool shared by the client.
     *
     * @return a pool sized by {@code max-idle-connections} with the configured keep-alive
     */
    @Bean
    public ConnectionPool pool() {
        return new ConnectionPool(maxIdleConnections, keepAliveDuration, TimeUnit.SECONDS);
    }
}
配置全局变量,创建WenXinConfig.java
package com.pet.common;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import okhttp3.MediaType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import java.io.IOException;
import java.util.Date;
@Configuration
@Slf4j
public class WenXinConfig {
@Value("${wenxin.apiKey}")
public String API_KEY;
@Value("${wenxin.secretKey}")
public String SECRET_KEY;
@Value("${wenxin.accessTokenUrl}")
public String ACCESS_TOKEN_URL;
@Value("${wenxin.ERNIE_Bot.URL}")
public String ERNIE_Bot_4_0_URL;
//过期时间为30天
public String ACCESS_TOKEN = null;
public String REFRESH_TOKEN = null;
public Date CREATE_TIME = null;//accessToken创建时间
public Date EXPIRATION_TIME = null;//accessToken到期时间
/**
* 获取accessToken
* @return true表示成功 false表示失败
*/
public synchronized String flushAccessToken(){
//判断当前AccessToken是否为空且判断是否过期
if(ACCESS_TOKEN != null && EXPIRATION_TIME.getTime() > CREATE_TIME.getTime()) return ACCESS_TOKEN;
//构造请求体 包含请求参数和请求头等信息
RequestBody body = RequestBody.create(MediaType.parse("application/json"),"");
Request request = new Request.Builder()
.url(ACCESS_TOKEN_URL+"?client_id="+API_KEY+"&client_secret="+SECRET_KEY+"&grant_type=client_credentials")
.method("POST", body)
.addHeader("Content-Type", "application/json")
.addHeader("Accept", "application/json")
.build();
OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
String response = null;
try {
//请求
response = HTTP_CLIENT.newCall(request).execute().body().string();
} catch (IOException e) {
log.error("ACCESS_TOKEN获取失败");
return null;
}
//刷新令牌以及更新令牌创建时间和过期时间
JSONObject jsonObject = JSON.parseObject(response);
ACCESS_TOKEN = jsonObject.getString("access_token");
REFRESH_TOKEN = jsonObject.getString("refresh_token");
CREATE_TIME = new Date();
EXPIRATION_TIME = new Date(Long.parseLong(jsonObject.getString("expires_in")) + CREATE_TIME.getTime());
return ACCESS_TOKEN;
}
}
配置application.yml
# The prefix must be "ok.http" to match the ${ok.http.*} placeholders in OkHttpConfiguration.
ok:
  http:
    connect-timeout: 30 # connect timeout, in seconds
    read-timeout: 30 # read timeout, in seconds
    write-timeout: 30 # write timeout, in seconds
    max-idle-connections: 200 # maximum number of idle connections kept in the pool
    keep-alive-duration: 300 # how long an idle connection is kept, in seconds
wenxin:
  apiKey: # your API key
  secretKey: # your secret key
  accessTokenUrl: https://aip.baidubce.com/oauth/2.0/token
  ERNIE_Bot:
    # Replace with the endpoint of the model you use. Keep this note on its own line:
    # a "#" that directly follows the URL without a space becomes part of the value.
    URL: https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k
创建TestController.java
package com.pet.controller;
import com.alibaba.fastjson.JSON;
import com.pet.common.WenXinConfig;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RestController
@Slf4j
@RequestMapping("/test")
public class TestController {
@Resource
private WenXinConfig wenXinConfig;
//历史对话,需要按照user,assistant
List<Map<String,String>> messages = new ArrayList<>();
/**
* 非流式问答
* @param question 用户的问题
* @return
* @throws IOException
*/
@PostMapping("/ask")
public String test1(String question) throws IOException {
if(question == null || question.equals("")){
return "请输入问题";
}
String responseJson = null;
//先获取令牌然后才能访问api
if (wenXinConfig.flushAccessToken() != null) {
HashMap<String, String> user = new HashMap<>();
user.put("role","user");
user.put("content",question);
messages.add(user);
String requestJson = constructRequestJson(1,0.95,0.8,1.0,false,messages);
RequestBody body = RequestBody.create(MediaType.parse("application/json"), requestJson);
Request request = new Request.Builder()
.url(wenXinConfig.ERNIE_Bot_4_0_URL + "?access_token=" + wenXinConfig.flushAccessToken())
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
try {
responseJson = HTTP_CLIENT.newCall(request).execute().body().string();
//将回复的内容转为一个JSONObject
JSONObject responseObject = JSON.parseObject(responseJson);
//将回复的内容添加到消息中
HashMap<String, String> assistant = new HashMap<>();
assistant.put("role","assistant");
assistant.put("content",responseObject.getString("result"));
messages.add(assistant);
} catch (IOException e) {
log.error("网络有问题");
return "网络有问题,请稍后重试";
}
}
return responseJson;
}
/**
* 构造请求的请求参数
* @param userId
* @param temperature
* @param topP
* @param penaltyScore
* @param messages
* @return
*/
public String constructRequestJson(Integer userId,
Double temperature,
Double topP,
Double penaltyScore,
boolean stream,
List<Map<String, String>> messages) {
Map<String,Object> request = new HashMap<>();
request.put("user_id",userId.toString());
request.put("temperature",temperature);
request.put("top_p",topP);
request.put("penalty_score",penaltyScore);
request.put("stream",stream);
request.put("messages",messages);
System.out.println(JSON.toJSONString(request));
return JSON.toJSONString(request);
}
}
postman测试
前端:创建advice.vue(使用了element-ui)
<template>
  <!-- Question input with action buttons, followed by the answer area. -->
  <div class="container">
    <div class="input-group">
      <input type="text" v-model="question" placeholder="请输入您的需求" class="input-field">
      <!-- "primary" is the element-ui button type; the previous value was misspelled. -->
      <el-button @click="ask" type="primary">获取建议</el-button>
      <el-button @click="reset" type="success">重置</el-button>
    </div>
    <div class="input-group">
      <div>建议: </div>
      <!-- Shows the answer; the loading mask is driven by the `loading` flag. -->
      <el-input
        type="textarea"
        :autosize="{ minRows: 4, maxRows: 8}"
        v-model="answer"
        v-loading="loading">
      </el-input>
    </div>
  </div>
</template>
<script>
export default {
name: 'Advice',
data() {
return {
question: '',
answer: '',
loading: false
}
},
methods: {
ask() {
if(this.question === '') {
this.$message('请输入需求');
}else{
this.loading = true
this.$axios.post(this.$httpUrl + '/test/ask?question=' + this.question).then(res => {
this.answer = res.data.result;
this.loading = false
}).catch(error => {
console.error('Error', error);
});
}
},
reset(){
this.question=""
this.answer=""
this.loading=false
}
}
}
</script>
<style scoped>
/* NOTE(review): this block is `scoped`; a `body` selector in a scoped block probably never
   matches the page body. Confirm, and move it to a global stylesheet if it is meant to apply. */
body {
  font-family: Arial, sans-serif;
  margin: 0;
  padding: 0;
  background-color: #f7f7f7;
}

/* Centered card that wraps the whole form. */
.container {
  max-width: 600px;
  margin: 100px auto;
  padding: 20px;
  /* background-color: #fff;*/
  border-radius: 5px;
  box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}

/* One horizontal row of the form. */
.input-group {
  display: flex;
  align-items: center;
  margin-bottom: 20px;
}

/* Text inputs. The two selectors carried identical declarations and are merged here.
   NOTE(review): no element with type="text2" is visible in this component. */
.input-group input[type="text"],
.input-group input[type="text2"] {
  flex: 1;
  padding: 10px;
  border: 1px solid #ccc;
  border-radius: 3px;
  font-size: 16px;
  transition: border-color 0.3s ease;
  resize: both; /* Allows resizing */
}

/* Action buttons inside a row. */
.input-group button {
  padding: 10px 20px;
  border: none;
  border-radius: 3px;
  cursor: pointer;
  font-size: 16px;
  transition: background-color 0.3s ease;
  margin-left: 5px;
}

/* NOTE(review): no element with class "suggestion" is visible in this component. */
.suggestion {
  margin-top: 30px; /* Increase top margin */
  padding: 20px; /* Increase padding */
  border: 1px solid #ccc;
  border-radius: 5px;
  font-size: 20px;
}
</style>